Source code: AI/NeuralNetworks/BackErrorPropagation.java
1 /* BackErrorPropagation.java */
2
3 package AI.NeuralNetworks;
4
5 import java.util.*;
6 import java.io.*;
7 import java.net.*;
8
9 /**
10 This Class is a nerual network with a multiple layer feed forward architecture where
11 the back error propagation could be run
12 */
13 public class BackErrorPropagation extends FeedForwardNetwork implements NeuralNetworkTeacher{
14 protected static final String COMMENT_TOKEN = "#";
15 protected static final String INPUT_TOKEN = "Input:";
16 protected static final String OUTPUT_TOKEN = "Output:";
17 protected static final String VECTOR_TOKEN = "Pattern:";
18 protected static final String END_TOKEN = "end";
19
20
21 protected static float DEFAULT_LEARNING_RATE = (float) 0.5;
22 protected static float DEFAULT_MOMENTUM_RATE = (float) 0.0;
23 protected static long DEFAULT_MAXIMUM_EPOCH = 100000;
24 protected static float DEFAULT_STOP_ERROR = (float) 0.0;
25
26 // protected FeedForwardNetwork network;
27 protected float learningRate;
28 protected float momentumRate;
29 protected long maximumEpoch;
30 protected float stopError;
31 protected Notify notifyError;
32 /**
33 Creates a new BackErrorPropagation instance
34 */
35 public BackErrorPropagation(FeedForwardNetwork network, float learningRate,
36 float momentumRate, long maximumEpoch, float stopError){
37 ListIterator originalNet;
38
39 layers = new LinkedList();
40
41 originalNet = network.layers.listIterator();
42 layers.add(new FeedForwardInputLayer((FeedForwardInputLayer) originalNet.next()));
43 while(originalNet.hasNext()){
44 layers.add(new FeedForwardWeightedTrainingLayer((FeedForwardWeightedLayer) originalNet.next()));
45 }
46 this.learningRate = learningRate;
47 this.momentumRate = momentumRate;
48 this.maximumEpoch = maximumEpoch;
49 this.stopError = stopError;
50 }
51
52 /**
53 Creates a new BackErrorPropagation instance
54 */
55 public BackErrorPropagation(FeedForwardNetwork network){
56 this(network, DEFAULT_LEARNING_RATE, DEFAULT_MOMENTUM_RATE, DEFAULT_MAXIMUM_EPOCH, DEFAULT_STOP_ERROR);
57 }
58
59 /** this construcctor should not be used it is here for inheriterence only */
60 protected BackErrorPropagation(){}
61
62 /**
63 Trains the network with the given patterns
64 @param patterns are the pairs of input-expected output that the network must learn
65 */
66 public float trainNetwork(LinkedList trainingPatterns)throws UnexpectedInputArraySizeException{
67 ListIterator patternList;
68 PatternPair pattern;
69 float[] outputError;
70 int epoch=0;
71 float error = (float) 1.0;
72
73 while((epoch<maximumEpoch) && (error>stopError)){
74 patternList = trainingPatterns.listIterator(0);
75
76 resetDeltaWeights();
77 error = 0;
78
79 //System.out.println(this.toString());
80
81 while(patternList.hasNext()){
82 pattern = (PatternPair) patternList.next();
83
84 adjustDeltaWeights(pattern.input(), pattern.output());
85
86 outputError = ((FeedForwardWeightedTrainingLayer) layers.getLast()).getOutputError();
87 for (int e=0; e<outputError.length; e++){
88 // error += Math.abs(outputError[e]);
89 error += outputError[e] * outputError[e];
90 }
91 }
92 // outputError = ((FeedForwardWeightedTrainingLayer) layers.getLast()).getOutputError();
93 error = (float) Math.sqrt(Math.abs(error / (trainingPatterns.size() * outputSize())));
94
95 if (notifyError != null)
96 notifyError.notify((Object) new Float(error) );
97 /*debug*///AI.Test.TestBP.printArray(null,"error = "+ error);
98
99 addDeltaWeights();
100
101 epoch++;
102 }
103
104 return error;
105 }
106
107 /**
108 Sets the acumulated delta weights to 0
109 */
110 public void resetDeltaWeights(){
111 ListIterator layer;
112
113 layer = layers.listIterator(1); //the first layer does not have weights
114
115 while(layer.hasNext()){
116 ((FeedForwardWeightedTrainingLayer) layer.next()).resetDeltaWeights();
117 }
118 }
119
120 /**
121 Calculates the delta weights for this pattern and adds it to the acumlated deltas
122 @param input is the input pattern
123 @param expectedOutput is the expected output patternfor this input pattern
124 */
125 public void adjustDeltaWeights(float[] input, float[] expectedOutput)throws UnexpectedInputArraySizeException{
126 ListIterator layer;
127 ListIterator previousLayer;
128 ListIterator nextLayer;
129 float[] output;
130 float error;
131
132 output = runVector(input);
133 layer = layers.listIterator(layers.size());
134 nextLayer = layers.listIterator(layers.size());
135 previousLayer = layers.listIterator(layers.size()-1);
136
137 ((FeedForwardWeightedTrainingLayer) layer.previous()).adjustDeltaWeights(expectedOutput, (FeedForwardLayer)previousLayer.previous());
138 while(layer.previousIndex() > 0){
139 ((FeedForwardWeightedTrainingLayer)layer.previous()).adjustDeltaWeights((FeedForwardLayer)previousLayer.previous(),(FeedForwardWeightedTrainingLayer)nextLayer.previous());
140 }
141
142 }
143
144 /**
145 Adds the deltaWeights to the weights
146 */
147 public void addDeltaWeights(){
148 ListIterator layer;
149
150 layer = layers.listIterator(1); //the first layer does not have weights
151
152 while(layer.hasNext()){
153 ((FeedForwardWeightedTrainingLayer) layer.next()).addDeltaWeights(learningRate);
154 }
155 }
156
157 /**
158 Sets an implementation of Notify, that will be notified of the error of each epoch
159 */
160 public void setNotifyError(Notify notifyError) {
161 this.notifyError = notifyError;
162 }
163
164 /**
165 Sets the learnning rate
166 */
167 public void setLearnningRate(float newRate) {
168 this.learningRate = newRate;
169 }
170
171 /**
172 Sets the maximum epoch
173 */
174 public void setMaximumEpoch(int newMaxEpoch) {
175 this.maximumEpoch = newMaxEpoch;
176 }
177
178 /**
179 Sets the stop error
180 */
181 public void setStopError(float newStopError) {
182 this.stopError = newStopError;
183 }
184
185 public static LinkedList getVectorsFormFile(String fileName){
186
187 // LinkedList patterns = new LinkedList();
188 BufferedReader input;
189 /* StringTokenizer st;
190 String token;
191 int insize;
192 int outsize;
193 float inVector[];
194 float outVector[];
195 */
196 try{
197 input = new BufferedReader(new FileReader(fileName));
198 return getVectors(input);
199 /*
200 // System.out.println("get input token ");
201 st = getNextStringTokenizer(input, INPUT_TOKEN);
202 insize = Integer.parseInt(st.nextToken());
203 // System.out.println("get output token ");
204 st = getNextStringTokenizer(input, OUTPUT_TOKEN);
205 outsize = Integer.parseInt(st.nextToken());
206 //br.readLine(); //resto de la linea
207 // System.out.println("geted output token ");
208
209 while(true){
210 inVector = new float[insize];
211 outVector = new float[outsize];
212
213 // System.out.println("get pattern token ");
214 st = getNextStringTokenizer(input, VECTOR_TOKEN);
215 for(int a=0; a<insize ; a++){
216 inVector[a] = Float.parseFloat(st.nextToken());
217 }
218 for(int a=0; a<outsize; a++){
219 outVector[a] = Float.parseFloat(st.nextToken());
220 }
221 patterns.add(new AI.NeuralNetworks.PatternPair(inVector, outVector));
222 }
223 } catch (EOFException eof) {
224 System.out.println("todo mal " );
225 // return patterns;
226 }/* catch (NoSuchElementException nsee) {
227 System.out.println("Todo bien ");
228 return patterns;*/
229 } catch (IOException ioe) {
230 System.out.println("The file " + fileName + " does not contain a valid input patterns");
231 ioe.printStackTrace();
232 return null;
233 }
234 }
235
236 public static LinkedList getVectors(BufferedReader input){
237
238 LinkedList patterns = new LinkedList();
239 // BufferedReader input;
240 StringTokenizer st;
241 String token;
242 int insize;
243 int outsize;
244 float inVector[];
245 float outVector[];
246
247 try{
248 // input = new BufferedReader(new FileReader(fileName));
249
250 // System.out.println("get input token ");
251 st = getNextStringTokenizer(input, INPUT_TOKEN);
252 insize = Integer.parseInt(st.nextToken());
253 // System.out.println("get output token ");
254 st = getNextStringTokenizer(input, OUTPUT_TOKEN);
255 outsize = Integer.parseInt(st.nextToken());
256 //br.readLine(); //resto de la linea
257 // System.out.println("geted output token ");
258
259 while(true){
260 inVector = new float[insize];
261 outVector = new float[outsize];
262
263 // System.out.println("get pattern token ");
264 st = getNextStringTokenizer(input, VECTOR_TOKEN);
265 for(int a=0; a<insize ; a++){
266 inVector[a] = Float.parseFloat(st.nextToken());
267 }
268 for(int a=0; a<outsize; a++){
269 outVector[a] = Float.parseFloat(st.nextToken());
270 }
271 patterns.add(new AI.NeuralNetworks.PatternPair(inVector, outVector));
272 }
273 } catch (EOFException eof) {
274 System.out.println("todo bien " /*+ patterns*/);
275 return patterns;
276 } catch (NoSuchElementException nsee) {
277 System.out.println("Todo bien " /*+ patterns*/);
278 return patterns;
279 } catch (IOException ioe) {
280 System.out.println(/*"The file " + fileName + */" does not contain a valid input patterns");
281 ioe.printStackTrace();
282 return null;
283 }
284 }
285
286 /**
287 converts a string composed by tokens that could be converted to int to an array of int
288 */
289 public static int[] parseStructure(String stringStructure){
290 StringTokenizer st = new StringTokenizer(stringStructure);
291 int numberLayers = st.countTokens();
292 int structure[] = new int[numberLayers];
293 for(int l=0; l<numberLayers; l++){
294 structure[l] = Integer.parseInt(st.nextToken());
295 }
296 return structure;
297 }
298
299 private static StringTokenizer getNextStringTokenizer(BufferedReader input, String token) throws IOException, NoSuchElementException{
300 StringTokenizer line;
301 boolean end = true;
302 String tkn = "";
303
304 // System.out.println("called");
305
306 try {
307 // String s = input.readLine();
308 String s = " ";
309 // line = new StringTokenizer(input.readLine());
310 // System.out.println(s);
311 line = new StringTokenizer(s);
312 while(!line.hasMoreTokens() || ((!(tkn = line.nextToken()).equals(token)) && (!(end=tkn.equals(END_TOKEN))) ) ){
313 // System.out.println(tkn);
314 line = new StringTokenizer(s=input.readLine());
315 // System.out.println(s);
316 }
317 if (end) throw new EOFException("End of file");
318 // System.out.println(s);
319 return line;
320 } catch (NoSuchElementException nsee) {
321 //throw new CoruptedFileException("This file is not a Feed Forewar Nerual Network file");
322 throw nsee;
323 }
324 }
325
326 }