| Home >> All >> AI >> [ NeuralNetworks Javadoc ] |
Source code: AI/NeuralNetworks/DistributedBackErrorPropagation.java
1 /* DistributedBackErrorPropagation.java */ 2 3 package AI.NeuralNetworks; 4 5 import java.util.*; 6 import java.net.*; 7 import java.io.*; 8 //import MPI.*; 9 10 /** 11 This Class is a nerual network with a multiple layer feed forward architecture where 12 a distributed version of the back error propagation algorithm could be run 13 */ 14 public class DistributedBackErrorPropagation extends BackErrorPropagation implements NeuralNetworkTeacher{ 15 // protected boolean master; 16 protected Set hosts; 17 protected Socket sock; 18 19 protected BufferedWriter bw; 20 21 // protected int rootProcess = 0; 22 // protected Comm communicator; 23 24 // public static int ROOT = 0; 25 // public static String LOAD_STRUCTURE = ""; 26 27 28 /** 29 Creates a new DistributedBackErrorPropagation instance 30 */ 31 public DistributedBackErrorPropagation(FeedForwardNetwork network, BufferedWriter bw){ 32 ListIterator originalNet; 33 34 layers = new LinkedList(); 35 36 originalNet = network.layers.listIterator(); 37 layers.add(new FeedForwardInputLayer((FeedForwardInputLayer) originalNet.next())); 38 while(originalNet.hasNext()){ 39 layers.add(new DistributedFeedForwardWeightedTrainingLayer((FeedForwardWeightedLayer) originalNet.next() /*, communicator, rootProcess*/ )); 40 } 41 this.bw = bw; 42 } 43 44 /** 45 Creates a new DistributedBackErrorPropagation instance 46 */ 47 public DistributedBackErrorPropagation(FeedForwardNetwork network){ 48 this(network, null); 49 } 50 /** 51 Trains the network with the given patterns 52 @param patterns are the pairs of input-expected output that the network must learn 53 */ 54 public float trainNetwork(LinkedList trainingPatterns) throws UnexpectedInputArraySizeException{ 55 ListIterator patternList; 56 PatternPair pattern; 57 float[] outputError; 58 // int epoch=0; 59 float error = 0; 60 // float[] tempError = new float[1]; 61 62 patternList = trainingPatterns.listIterator(0); 63 64 resetDeltaWeights(); 65 error = 0; 66 67 //System.out.println(this.toString()); 68 69 while(patternList.hasNext()){ 70 pattern = (PatternPair) patternList.next(); 71 72 // System.out.println("\t"+r+"\tepoch: "+epoch+"\tpattern: "+niceArray(pattern.input())); //debug 73 adjustDeltaWeights(pattern.input(), pattern.output()); 74 75 outputError = ((FeedForwardWeightedTrainingLayer) layers.getLast()).getOutputError(); 76 for (int e=0; e<outputError.length; e++){ 77 error += outputError[e] * outputError[e]; 78 //System.out.println("errorAcum " + error ); 79 } 80 } 81 // outputError = ((FeedForwardWeightedTrainingLayer) layers.getLast()).getError(); 82 // tempError[0] = error; 83 // tempError = communicator.AllReduceSumFloat(tempError, rootProcess); 84 // error = (float) Math.sqrt(Math.abs(tempError[0] / (trainingPatterns.size() * outputError.length))); 85 86 // addDeltaWeights(); 87 88 sendError(error); 89 //System.out.println("errorSend " + error ); 90 sendDeltaWeights(); 91 92 return error; 93 } 94 95 public void sendError(float error){ 96 try{ 97 /**/ 98 int errorBits = Float.floatToIntBits(error); 99 bw.write( errorBits & 0x000000ff); 100 bw.write((errorBits & 0x0000ff00) >>>8); 101 bw.write((errorBits & 0x00ff0000) >>>16); 102 bw.write((errorBits & 0xff000000) >>>24); 103 //System.out.println("sendError = "+ error);//debug 104 /**/ 105 //bw.write(""+error); // cambiar aca que manda el error como String 106 //bw.newLine(); 107 bw.flush(); 108 } catch (IOException ioe) { 109 ioe.printStackTrace(); 110 } 111 } 112 113 public void sendDeltaWeights(){// cambiar aca que manda los pesos como String 114 ListIterator layer; 115 116 layer = layers.listIterator(1); //the first layer does not have weights 117 118 while(layer.hasNext()){ 119 ((DistributedFeedForwardWeightedTrainingLayer) layer.next()).sendDeltaWeights(bw); 120 } 121 } 122 123 /** 124 Adds the all deltaWeights to the weights 125 */ 126 public void addDeltaWeights(){ 127 ListIterator layer; 128 129 layer = layers.listIterator(1); //the first layer does not have weights 130 131 while(layer.hasNext()){ 132 ((DistributedFeedForwardWeightedTrainingLayer) layer.next()).addDeltaWeights(learningRate); 133 } 134 } 135 136 /** 137 adds the delta weights form br to the local delta weights 138 */ 139 public void parseDeltaWeights(BufferedReader br){ 140 ListIterator layer; 141 142 layer = layers.listIterator(1); //the first layer does not have weights 143 144 while(layer.hasNext()){ 145 ((DistributedFeedForwardWeightedTrainingLayer) layer.next()).parseDeltaWeights(br); 146 } 147 } 148 149 150 151 static private String niceArray(float array[]){ 152 String s = "[ "; 153 for(int i=0; i<array.length; i++){ 154 s += array[i] + " "; 155 } 156 s+= "]"; 157 return s; 158 } 159 static private String niceArray(int array[]){ 160 String s = "[ "; 161 for(int i=0; i<array.length; i++){ 162 s += array[i] + " "; 163 } 164 s+= "]"; 165 return s; 166 } 167 168 static private void showList(LinkedList pat){ //debug 169 ListIterator li = pat.listIterator(0); 170 171 while(li.hasNext()){ 172 System.out.println(((PatternPair) li.next()).toString()); 173 } 174 175 } 176 177 }