Source code: riso/distributions/ConditionalGaussian.java
1 /* RISO: an implementation of distributed belief networks.
2 * Copyright (C) 1999, Robert Dodier.
3 *
4 * This program is free software; you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation; either version 2 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program; if not, write to the Free Software
16 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA, 02111-1307, USA,
17 * or visit the GNU web site, www.gnu.org.
18 */
19 package riso.distributions;
20 import java.io.*;
21 import java.rmi.*;
22 import riso.numerical.*;
23 import riso.general.*;
24
25 /** An instance of this class represents a conditional Gaussian distribution.
26 * The dependence enters only through the mean, which is a linear combination
27 * the parents plus an offset. The variance is constant.
28 *
29 * <p> Writing the marginal means of the child and parent variables,
30 * respectively, as <tt>mu(1)</tt> and <tt>mu(2)</tt>, and the respective
31 * marginal variances as <tt>Sigma(11)</tt> and <tt>Sigma(22)</tt>, and the
32 * covariance as <tt>Sigma(12)</tt>, then the conditional mean <tt>mu(1|2)</tt>
33 * and conditional variance <tt>Sigma(1|2)</tt> are as follows.
34 * <pre>
35 * mu(1|2) = mu(1) + Sigma(12) Sigma(22)^{-1} (X(2)-mu(2))
36 * Sigma(1|2) = Sigma(11) - Sigma(12) Sigma(22)^{-1} Sigma(21)
37 * </pre>
38 * where the parent variables appear as <tt>X(2)</tt>.
39 * These parameters are named as follows in the description for an object
40 * of this type:
41 * <pre>
42 * conditional-mean-multiplier == Sigma(12) Sigma(22)^{-1}
43 * conditional-mean-offset == mu(1) - Sigma(12) Sigma(22)^{-1} mu(2)
44 * conditional-variance == Sigma(1|2)
45 * </pre>
46 * In the code, these three parameters are called <tt>a_mu_1c2</tt>,
47 * <tt>b_mu_1c2</tt>, and <tt>Sigma_1c2</tt>, respectively.
48 */
49 public class ConditionalGaussian extends AbstractConditionalDistribution
50 {
51 public double[][] Sigma_1c2_inverse = null;
52 public double det_Sigma_1c2 = 0;
53
54 // These strings contain vectors and matrices; in order to parse these,
55 // we have to wait until we know how many parents there are.
56
57 String Sigma_1c2_string = null;
58 String a_mu_1c2_string = "{ 1 }";
59 String b_mu_1c2_string = "{ 0 }";
60
61 /** Offset for conditional mean calculation. The conditional mean is calculated as
62 * <tt>a_mu_1c2 * x2 + b_mu_1c2</tt>, where <tt>x2</tt> is the vector of variables
63 * on which we are conditioning.
64 */
65 public double[] b_mu_1c2;
66
67 /** Multiplier for conditional mean calculation.
68 */
69 public double[][] a_mu_1c2;
70
71 /** Covariance matrix of the conditional distribution.
72 * This matrix has number of rows and columns equal to the dimension of
73 * the child.
74 */
75 public double[][] Sigma_1c2;
76
77 /** Do-nothing constructor, so <tt>Class.forName</tt> works.
78 */
79 public ConditionalGaussian() {}
80
81 /** Return a deep copy of this object. If the matrices haven't already been
82 * parsed, parse the description strings now.
83 */
84 public Object clone() throws CloneNotSupportedException
85 {
86 try { check_matrices(); }
87 catch (Exception e) { throw new CloneNotSupportedException( this.getClass().getName()+".clone failed: "+e ); }
88
89 ConditionalGaussian copy = (ConditionalGaussian) super.clone();
90 copy.b_mu_1c2 = (b_mu_1c2 == null ? null : (double[])b_mu_1c2.clone());
91 copy.a_mu_1c2 = (a_mu_1c2 == null ? null : (double[][])a_mu_1c2.clone());
92 copy.Sigma_1c2 = (Sigma_1c2 == null ? null : (double[][])Sigma_1c2.clone());
93
94 return copy;
95 }
96
97 /** Return the number of dimensions of the child variable.
98 */
99 public int ndimensions_child()
100 {
101 try { check_matrices(); }
102 catch (Exception e) { throw new RuntimeException( "ConditionalGaussian.ndimensions_child: failed:\n\t"+e ); }
103 return a_mu_1c2.length;
104 }
105
106 /** Return the number of dimensions of the parent variables.
107 * If there is more than one parent, this is the sum of the dimensions
108 * of the parent variables.
109 */
110 public int ndimensions_parent()
111 {
112 try { check_matrices(); }
113 catch (Exception e) { throw new RuntimeException( "ConditionalGaussian.ndimensions_parent: failed:\n\t"+e ); }
114 return a_mu_1c2[0].length;
115 }
116
117 /** For a given value <code>c</code> of the parents, return a distribution
118 * which represents <code>p(x|C=c)</code>. Executing <code>get_density(c).
119 * p(x)</code> will yield the same result as <code>p(x,c)</code>.
120 */
121 public Distribution get_density( double[] c ) throws Exception
122 {
123 check_matrices();
124 double[] mu = (double[]) b_mu_1c2.clone();
125 Matrix.add( mu, Matrix.multiply( a_mu_1c2, c ) );
126 return new Gaussian( mu, Sigma_1c2 );
127 }
128
129 /** Compute the density at the point <code>x</code>.
130 * @param x Point at which to evaluate density.
131 * @param c Values of parent variables.
132 */
133 public double p( double[] x, double[] c ) throws Exception
134 {
135 check_matrices();
136 double[] mu = (double[]) b_mu_1c2.clone();
137 Matrix.add( mu, Matrix.multiply( a_mu_1c2, c ) );
138
139 if ( Sigma_1c2_inverse == null ) Sigma_1c2_inverse = Matrix.inverse( Sigma_1c2 );
140 if ( det_Sigma_1c2 == 0 ) det_Sigma_1c2 = Matrix.determinant( Sigma_1c2 );
141
142 return Gaussian.g( x, mu, Sigma_1c2_inverse, det_Sigma_1c2 );
143 }
144
145 /** Return an instance of a random variable from this distribution.
146 * @param c Parent variables.
147 */
148 public double[] random( double[] c ) throws Exception
149 {
150 check_matrices();
151 System.err.println( "ConditionalGaussian.random: VERY SLOW IMPLEMENTATION!!!" );
152 return get_density( c ).random();
153 }
154
155 /** Create a description of this distribution model as a string.
156 * This is a full description, suitable for printing, containing
157 * newlines and indents.
158 *
159 * @param leading_ws Leading whitespace string. This is written at
160 * the beginning of each line of output. Indents are produced by
161 * appending more whitespace.
162 */
163 public String format_string( String leading_ws ) throws IOException
164 {
165 int i, j;
166 String result = "", more_leading_ws = leading_ws+"\t", still_more_ws = leading_ws+"\t\t";
167
168 result += this.getClass().getName()+"\n"+leading_ws+"{"+"\n";
169
170 check_matrices();
171
172 result += more_leading_ws+"conditional-mean-multiplier";
173 if ( a_mu_1c2.length == 1 && a_mu_1c2[0].length == 1 )
174 result += " { "+a_mu_1c2[0][0]+" }\n";
175 else
176 {
177 result += "\n"+more_leading_ws+"{\n";
178 for ( i = 0; i < a_mu_1c2.length; i++ )
179 {
180 result += still_more_ws;
181 for ( j = 0; j < a_mu_1c2[i].length; j++ )
182 result += a_mu_1c2[i][j]+" ";
183 result += "\n";
184 }
185 result += more_leading_ws+"}\n";
186 }
187
188 result += more_leading_ws+"conditional-mean-offset { ";
189 for ( i = 0; i < b_mu_1c2.length; i++ )
190 result += b_mu_1c2[i]+" ";
191 result += "}\n";
192
193 result += more_leading_ws+"conditional-variance";
194 if ( Sigma_1c2.length == 1 )
195 result += " { "+Sigma_1c2[0][0]+" }\n";
196 else
197 {
198 result += "\n"+more_leading_ws+"{\n";
199 for ( i = 0; i < Sigma_1c2.length; i++ )
200 {
201 result += still_more_ws;
202 for ( j = 0; j < Sigma_1c2[i].length; j++ )
203 result += Sigma_1c2[i][j]+" ";
204 result += "\n";
205 }
206 result += more_leading_ws+"}\n";
207 }
208
209 result += leading_ws+"}\n";
210 return result;
211 }
212
213 /** Read in a <tt>ConditionalGaussian</tt> from an input stream. This is intended
214 * for input from a human-readable source; this is different from object serialization.
215 * @param st Stream tokenizer to read from.
216 * @throws IOException If the attempt to read the model fails.
217 */
218 public void pretty_input( SmarterTokenizer st ) throws IOException
219 {
220 boolean found_closing_bracket = false;
221
222 try
223 {
224 st.nextToken();
225 if ( st.ttype != '{' )
226 throw new IOException( "ConditionalGaussian.pretty_input: input doesn't have opening bracket." );
227
228 for ( st.nextToken(); !found_closing_bracket && st.ttype != StreamTokenizer.TT_EOF; st.nextToken() )
229 {
230 if ( st.ttype == StreamTokenizer.TT_WORD && st.sval.equals( "conditional-mean-multiplier" ) )
231 {
232 st.nextBlock();
233 a_mu_1c2_string = st.sval;
234 System.err.println( "CG: found a_mu_1c2_string: "+a_mu_1c2_string );
235 }
236 else if ( st.ttype == StreamTokenizer.TT_WORD && st.sval.equals( "conditional-mean-offset" ) )
237 {
238 st.nextBlock();
239 b_mu_1c2_string = st.sval;
240 System.err.println( "CG: found b_mu_1c2_string: "+b_mu_1c2_string );
241 }
242 else if ( st.ttype == StreamTokenizer.TT_WORD && st.sval.equals( "conditional-variance" ) )
243 {
244 st.nextBlock();
245 Sigma_1c2_string = st.sval;
246 System.err.println( "CG: found Sigma_1c2_string: "+Sigma_1c2_string );
247 }
248 else if ( st.ttype == '}' )
249 {
250 found_closing_bracket = true;
251 break;
252 }
253 }
254 }
255 catch (IOException e)
256 {
257 throw new IOException( "ConditionalGaussian.pretty_input: attempt to read object failed:\n"+e );
258 }
259
260 if ( ! found_closing_bracket )
261 throw new IOException( "ConditionalGaussian.pretty_input: no closing bracket on input." );
262 }
263
264 /** If vectors and matrices descriptions have not yet been parsed,
265 * do so now. If they are already parsed, do nothing.
266 *
267 * @throws IOException If the description parsing fails.
268 * @throws RemoteException If the attempt to reference parents fails.
269 */
270 public void check_matrices() throws IOException, RemoteException
271 {
272 int nchild = 1; // THIS IS THE ONLY PLACE THE CHILD DIMENSION IS RESTRICTED; CHANGE ??? !!!
273
274 if ( Sigma_1c2 == null )
275 {
276 // First figure out how many elements there are in the a_mu_1c2 description;
277 // this is equal to nchild*nparents, so set nparents = nelements/nchild.
278
279 int nelements = -2; // don't count the left and right parentheses.
280 SmarterTokenizer st = new SmarterTokenizer( new StringReader(a_mu_1c2_string) );
281 for ( st.nextToken(); st.ttype != StreamTokenizer.TT_EOF; st.nextToken() )
282 ++nelements;
283 int nparents = nelements/nchild;
284
285 Sigma_1c2 = parse_matrix( Sigma_1c2_string, nchild, nchild );
286 a_mu_1c2 = parse_matrix( a_mu_1c2_string, nchild, nparents );
287 b_mu_1c2 = parse_vector( b_mu_1c2_string, nchild );
288 }
289 }
290
291 static double[][] parse_matrix( String s, int nrows, int ncols ) throws IOException
292 {
293 double[][] A = new double[nrows][ncols];
294 SmarterTokenizer st = new SmarterTokenizer( new StringReader( s ) );
295
296 st.nextToken();
297 if ( st.ttype != '{' )
298 throw new IOException( "ConditionalGaussian.parse_matrix: input doesn't have opening bracket." );
299
300 for ( int i = 0; i < nrows; i++ )
301 for ( int j = 0; j < ncols; j++ )
302 {
303 st.nextToken();
304 A[i][j] = Double.parseDouble( st.sval );
305 }
306
307 st.nextToken();
308 if ( st.ttype != '}' )
309 throw new IOException( "ConditionalGaussian.parse_matrix: input doesn't have closing bracket." );
310
311 return A;
312 }
313
314 static double[] parse_vector( String s, int n ) throws IOException
315 {
316 double[] x = new double[n];
317 SmarterTokenizer st = new SmarterTokenizer( new StringReader( s ) );
318
319 st.nextToken();
320 if ( st.ttype != '{' )
321 throw new IOException( "ConditionalGaussian.parse_matrix: input doesn't have opening bracket." );
322
323 for ( int j = 0; j < n; j++ )
324 {
325 st.nextToken();
326 x[j] = Double.parseDouble( st.sval );
327 }
328
329 st.nextToken();
330 if ( st.ttype != '}' )
331 throw new IOException( "ConditionalGaussian.parse_matrix: input doesn't have closing bracket." );
332
333 return x;
334 }
335 }