Docjar: A Java Source and Document Engine    com.*    java.*    javax.*    org.*    all    new    plug-in

Quick Search    Search Deep

Source code: AI/NeuralNetworks/BackErrorPropagation.java


1   /* BackErrorPropagation.java */
2   
3   package AI.NeuralNetworks;
4   
5   import java.util.*;
6   import java.io.*;
7   import java.net.*;
8   
9   /** 
10    This Class is a nerual network with a multiple layer feed forward architecture where
11    the back error propagation could be run
12  */
13  public class BackErrorPropagation extends FeedForwardNetwork implements NeuralNetworkTeacher{
14    protected static final String COMMENT_TOKEN = "#";
15    protected static final String INPUT_TOKEN = "Input:";
16    protected static final String OUTPUT_TOKEN = "Output:";
17    protected static final String VECTOR_TOKEN = "Pattern:";
18    protected static final String END_TOKEN = "end";
19  
20  
21    protected static float DEFAULT_LEARNING_RATE = (float) 0.5;
22    protected static float DEFAULT_MOMENTUM_RATE = (float) 0.0;
23    protected static long DEFAULT_MAXIMUM_EPOCH = 100000;
24    protected static float DEFAULT_STOP_ERROR = (float) 0.0;
25    
26  //  protected FeedForwardNetwork network;
27    protected float learningRate;
28    protected float momentumRate;
29    protected long maximumEpoch;
30    protected float stopError;
31    protected Notify notifyError;
32    /**
33     Creates a new BackErrorPropagation instance
34     */
35    public BackErrorPropagation(FeedForwardNetwork network, float learningRate,
36                                float momentumRate, long maximumEpoch, float stopError){
37      ListIterator originalNet;
38      
39      layers = new LinkedList();
40      
41      originalNet = network.layers.listIterator();
42      layers.add(new FeedForwardInputLayer((FeedForwardInputLayer) originalNet.next()));
43      while(originalNet.hasNext()){
44        layers.add(new FeedForwardWeightedTrainingLayer((FeedForwardWeightedLayer) originalNet.next()));
45      }
46      this.learningRate = learningRate;
47      this.momentumRate = momentumRate;
48      this.maximumEpoch = maximumEpoch;
49      this.stopError = stopError;
50    }
51  
52    /**
53     Creates a new BackErrorPropagation instance
54     */
55    public BackErrorPropagation(FeedForwardNetwork network){    
56      this(network, DEFAULT_LEARNING_RATE, DEFAULT_MOMENTUM_RATE, DEFAULT_MAXIMUM_EPOCH, DEFAULT_STOP_ERROR);
57    }  
58    
59    /** this construcctor should not be used it is here for inheriterence only */
60    protected BackErrorPropagation(){}
61    
62    /**
63     Trains the network with the  given patterns
64     @param patterns are the pairs of input-expected output that the network must learn
65     */
66    public float trainNetwork(LinkedList trainingPatterns)throws UnexpectedInputArraySizeException{
67      ListIterator patternList;
68      PatternPair pattern;
69      float[] outputError;
70      int epoch=0;
71      float error = (float) 1.0;
72      
73      while((epoch<maximumEpoch) && (error>stopError)){
74        patternList = trainingPatterns.listIterator(0);
75      
76        resetDeltaWeights();
77        error = 0;
78  
79  //System.out.println(this.toString());    
80        
81        while(patternList.hasNext()){
82          pattern = (PatternPair) patternList.next();
83        
84          adjustDeltaWeights(pattern.input(), pattern.output());
85  
86          outputError = ((FeedForwardWeightedTrainingLayer) layers.getLast()).getOutputError();
87          for (int e=0; e<outputError.length; e++){
88  //          error += Math.abs(outputError[e]);
89            error += outputError[e] * outputError[e];
90          }
91        }
92  //      outputError = ((FeedForwardWeightedTrainingLayer) layers.getLast()).getOutputError();
93        error = (float) Math.sqrt(Math.abs(error / (trainingPatterns.size() * outputSize())));
94  
95        if (notifyError != null) 
96          notifyError.notify((Object) new Float(error) );
97  /*debug*///AI.Test.TestBP.printArray(null,"error = "+ error);
98  
99        addDeltaWeights();      
100       
101       epoch++;
102     }
103     
104     return error;
105   }
106   
107   /**
108     Sets the acumulated delta weights to 0
109   */
110   public void resetDeltaWeights(){
111     ListIterator layer;
112     
113     layer = layers.listIterator(1); //the first layer does not have weights
114     
115     while(layer.hasNext()){
116       ((FeedForwardWeightedTrainingLayer) layer.next()).resetDeltaWeights();
117     }
118   }
119   
120   /**
121     Calculates the delta weights for this pattern and adds it to the acumlated deltas
122     @param input is the input pattern
123     @param expectedOutput is the expected output patternfor this input pattern
124   */
125   public void adjustDeltaWeights(float[] input, float[] expectedOutput)throws UnexpectedInputArraySizeException{
126     ListIterator layer;
127     ListIterator previousLayer;
128     ListIterator nextLayer;
129     float[] output;
130     float  error;
131     
132     output = runVector(input);
133     layer = layers.listIterator(layers.size());
134     nextLayer = layers.listIterator(layers.size());
135     previousLayer = layers.listIterator(layers.size()-1);
136     
137     ((FeedForwardWeightedTrainingLayer) layer.previous()).adjustDeltaWeights(expectedOutput, (FeedForwardLayer)previousLayer.previous());
138     while(layer.previousIndex() > 0){
139       ((FeedForwardWeightedTrainingLayer)layer.previous()).adjustDeltaWeights((FeedForwardLayer)previousLayer.previous(),(FeedForwardWeightedTrainingLayer)nextLayer.previous());
140     }
141     
142   }
143   
144   /**
145     Adds the deltaWeights to the weights
146   */
147   public void addDeltaWeights(){
148     ListIterator layer;
149     
150     layer = layers.listIterator(1); //the first layer does not have weights
151     
152     while(layer.hasNext()){
153       ((FeedForwardWeightedTrainingLayer) layer.next()).addDeltaWeights(learningRate);
154     }
155   }
156   
157   /**
158     Sets an implementation of Notify, that will be notified of the error of each epoch
159   */
160   public void setNotifyError(Notify notifyError) {
161     this.notifyError = notifyError;
162   }
163   
164   /**
165     Sets the learnning rate
166   */
167   public void setLearnningRate(float newRate) {
168     this.learningRate = newRate;
169   }
170   
171   /**
172     Sets the maximum epoch
173   */
174   public void setMaximumEpoch(int newMaxEpoch) {
175     this.maximumEpoch = newMaxEpoch;
176   }
177 
178   /**
179     Sets the stop error
180   */
181   public void setStopError(float newStopError) {
182     this.stopError = newStopError;
183   }
184 
185         public static LinkedList getVectorsFormFile(String fileName){
186 
187 //    LinkedList patterns = new LinkedList();
188     BufferedReader input;
189 /*    StringTokenizer st;
190     String token;
191     int insize;
192     int outsize;
193     float inVector[];  
194     float outVector[];
195 */    
196     try{
197       input = new BufferedReader(new FileReader(fileName));
198                         return getVectors(input);
199 /*      
200 //      System.out.println("get input token ");
201       st = getNextStringTokenizer(input, INPUT_TOKEN);
202       insize = Integer.parseInt(st.nextToken());
203 //      System.out.println("get output token ");
204       st = getNextStringTokenizer(input, OUTPUT_TOKEN);
205       outsize = Integer.parseInt(st.nextToken());
206       //br.readLine();  //resto de la linea   
207 //      System.out.println("geted output token ");
208       
209       while(true){
210         inVector = new float[insize];
211         outVector = new float[outsize];
212         
213 //        System.out.println("get pattern token ");
214         st = getNextStringTokenizer(input, VECTOR_TOKEN);
215         for(int a=0; a<insize ; a++){
216           inVector[a] = Float.parseFloat(st.nextToken());
217         }
218         for(int a=0; a<outsize; a++){
219           outVector[a] = Float.parseFloat(st.nextToken());
220         }
221         patterns.add(new AI.NeuralNetworks.PatternPair(inVector, outVector));
222       }
223     } catch (EOFException eof) {
224       System.out.println("todo mal " );
225     //  return patterns;
226     }/* catch (NoSuchElementException nsee) {
227       System.out.println("Todo bien ");
228       return patterns;*/
229     } catch (IOException ioe) {
230       System.out.println("The file " + fileName + " does not contain a valid input patterns");
231       ioe.printStackTrace();
232       return null;
233     }
234   }
235                 
236   public static LinkedList getVectors(BufferedReader input){
237 
238     LinkedList patterns = new LinkedList();
239 //    BufferedReader input;
240     StringTokenizer st;
241     String token;
242     int insize;
243     int outsize;
244     float inVector[];  
245     float outVector[];
246     
247     try{
248 //      input = new BufferedReader(new FileReader(fileName));
249       
250 //      System.out.println("get input token ");
251       st = getNextStringTokenizer(input, INPUT_TOKEN);
252       insize = Integer.parseInt(st.nextToken());
253 //      System.out.println("get output token ");
254       st = getNextStringTokenizer(input, OUTPUT_TOKEN);
255       outsize = Integer.parseInt(st.nextToken());
256       //br.readLine();  //resto de la linea   
257 //      System.out.println("geted output token ");
258       
259       while(true){
260         inVector = new float[insize];
261         outVector = new float[outsize];
262         
263 //        System.out.println("get pattern token ");
264         st = getNextStringTokenizer(input, VECTOR_TOKEN);
265         for(int a=0; a<insize ; a++){
266           inVector[a] = Float.parseFloat(st.nextToken());
267         }
268         for(int a=0; a<outsize; a++){
269           outVector[a] = Float.parseFloat(st.nextToken());
270         }
271         patterns.add(new AI.NeuralNetworks.PatternPair(inVector, outVector));
272       }
273     } catch (EOFException eof) {
274       System.out.println("todo bien " /*+ patterns*/);
275       return patterns;
276     } catch (NoSuchElementException nsee) {
277       System.out.println("Todo bien " /*+ patterns*/);
278       return patterns;
279     } catch (IOException ioe) {
280       System.out.println(/*"The file " + fileName + */" does not contain a valid input patterns");
281       ioe.printStackTrace();
282       return null;
283     }
284   }
285   
286   /**
287     converts a string composed by tokens that could be converted to int to an array of int
288   */
289   public static int[] parseStructure(String stringStructure){
290     StringTokenizer st = new StringTokenizer(stringStructure);
291     int numberLayers = st.countTokens();
292     int structure[] = new int[numberLayers];
293     for(int l=0; l<numberLayers; l++){
294       structure[l] = Integer.parseInt(st.nextToken());
295     }
296     return structure;
297   }
298   
299   private static StringTokenizer getNextStringTokenizer(BufferedReader input, String token) throws IOException, NoSuchElementException{
300     StringTokenizer line;
301     boolean end = true;
302     String tkn = "";
303     
304 //    System.out.println("called");
305     
306     try {
307 //      String s = input.readLine();
308       String s = " ";
309 //      line = new StringTokenizer(input.readLine());
310 //        System.out.println(s);
311       line = new StringTokenizer(s);
312        while(!line.hasMoreTokens() || ((!(tkn = line.nextToken()).equals(token)) && (!(end=tkn.equals(END_TOKEN))) ) ){
313 //        System.out.println(tkn);
314         line = new StringTokenizer(s=input.readLine());
315 //        System.out.println(s);
316       }
317       if (end) throw new EOFException("End of file");
318 //      System.out.println(s);
319        return line;
320      } catch (NoSuchElementException nsee) {
321       //throw new CoruptedFileException("This file is not a Feed Forewar Nerual Network file");
322       throw nsee;
323     }
324   }
325   
326 }