
# file "Rtrans2" -- a transdimensional Metropolis-Hastings algorithm

# THE DATA:

# 40 points from N(-1,1). They are stored as a literal, so the data set is
# identical on every run; only the sampler below uses random numbers.
# These 40 values are the building block for the mixture data `y` below.
xnorm =
c(0.75745087478711, -3.67833850920761, 0.773765294923852, -2.5166350926628, 
-0.866664528466648, -1.63695628038082, -2.69436694486401, -1.17670482134698, 
-0.93268038000603, -2.68863707210148, -2.55723562317876, -0.96943475222958, 
-0.502247844160163, -0.95460405243514, -1.35246808601944, 0.430539874991081, 
-0.216421013142778, -0.0735422349319686, 0.929234311847512, -2.55736988978725, 
0.659802037906305, 0.487098508940152, -2.25146570024105, -1.54900454926781, 
-0.0399614473021575, -1.27424418773441, -1.39537014482172, -0.378208795173341, 
-1.43844136511807, -2.03013360311522, -2.60644108937020, -1.03736744149929, 
-1.06954070965917, -1.97334401699724, 0.34551263753373, -1.00174729815340, 
0.227300814797052, -1.21912077204945, -2.36256531047837, -1.07478823528914
)

# Build an 8-component mixture data set (k = 8): column j of the outer sum is
# xnorm shifted by the j-th offset. Flattening column by column gives
# xnorm - 5, then xnorm, xnorm + 2, ..., xnorm + 22 (8 * 40 = 320 points).
y <- as.vector(outer(xnorm, c(-5, 0, 2, 5, 9, 12, 20, 22), "+"))
# ALT: y <- 2 * y   # rescale the data set
J <- length(y)  # number of observations

# Log of the (unnormalised) joint target density of (k, avect[1:k]).
#   prior on k:     k - 1 ~ Poisson(2)
#   prior on means: avect[1], ..., avect[k] iid N(0, 1)
#   likelihood:     each y[j] ~ (1/k) * sum_i N(avect[i], 1). The
#                   -0.5 * log(2 * pi) constant per observation is omitted,
#                   because it depends on neither k nor avect and cancels in
#                   Metropolis-Hastings ratios.
# Args:
#   k      number of active mixture components.
#   avect  numeric vector whose first k entries are the component means;
#          entries beyond k are ignored.
# Reads the globals y (data) and kmax (largest allowed k).
# Returns -Inf when k lies outside 1..kmax.
logg <- function(k, avect) {
  if ((k < 1) || (k > kmax)) {
    return(-Inf)
  }
  a <- avect[seq_len(k)]
  logprior <- dpois(k - 1, 2, log = TRUE) - (k / 2) * log(2 * pi) -
    0.5 * sum(a^2)
  # logkern[j, i] = -(y[j] - a[i])^2 / 2 (one row per observation).
  # Summing exp(logkern) directly underflows to 0 once every component is
  # more than ~38 units from y[j], which makes log(0) = -Inf for a state with
  # positive density. Log-sum-exp (subtract each row's max first) avoids that.
  logkern <- -outer(y, a, "-")^2 / 2
  rowmax <- apply(logkern, 1, max)
  loglik <- sum(rowmax + log(rowSums(exp(logkern - rowmax)))) -
    length(y) * log(k)  # equal mixture weights 1/k for every observation
  logprior + loglik
}

# Sampler settings and initial state.
dotrans <- TRUE  # TRUE: also attempt between-dimension (k-changing) moves
kmax <- 100      # largest k allowed; logg() returns -Inf above it
M <- 500         # run length
# Starting k: fixed at 3 without dimension moves; otherwise 1 + Poisson(2),
# the same law as the prior on k used in logg().
k <- if (dotrans) 1 + rpois(1, 2) else 3
# Invariant maintained throughout: avect[i] == 0 for every i > k.
avect <- numeric(kmax)
avect[seq_len(k)] <- rnorm(k)
# ALT: avect[seq_len(k)] <- rnorm(k, 0, 5)  # overdispersed starting distribution
sigma <- 0.3  # scale of the within-dimension random-walk proposal
# Per-iteration traces of the chain: k and the first three means.
klist <- numeric(M)
a1list <- numeric(M)
a2list <- numeric(M)
a3list <- numeric(M)
numaccept <- 0  # number of accepted within-dimension moves
latestlog <- logg(k, avect)
print(latestlog)
cat("initial k =", k, "\n")

# Run the chain for M iterations. Each iteration does
#   (1) a fixed-dimension random-walk Metropolis update of avect[1:k], then
#   (2) if dotrans, one reversible-jump birth/death proposal that changes k by 1.
# Updates the globals k, avect, latestlog and numaccept, and records the trace
# in klist, a1list, a2list and a3list.
for (itnum in seq_len(M)) {

  # (1) Within-dimension move: perturb each active mean by sigma * N(0, 1).
  newavect <- avect
  newavect[seq_len(k)] <- newavect[seq_len(k)] + sigma * rnorm(k)
  newlog <- logg(k, newavect)
  logU <- log(runif(1))  # for accept/reject
  logA <- newlog - latestlog
  # Requiring newlog > -Inf never accepts a zero-density state. It also avoids
  # comparing against NaN when latestlog is -Inf too.
  if ((newlog > -Inf) && (logU < logA)) {
    avect <- newavect
    latestlog <- newlog
    numaccept <- numaccept + 1  # counts within-dimension moves only
  }

  # (2) Between-dimension move: newk = k - 1 or k + 1, each with probability
  # 1/2, so the move-type probabilities cancel in the acceptance ratio.
  if (dotrans) {
    newk <- k - 1 + 2 * floor(runif(1, 0, 2))
    newavect <- avect
    if (newk > k) {
      # Birth: split avect[k] into (avect[k] - Z, avect[k] + Z), Z ~ N(0, 1).
      Z <- rnorm(1)
      newavect[k] <- avect[k] - Z
      newavect[newk] <- avect[k] + Z
      # Divide by the proposal density of Z, then multiply by the Jacobian of
      # (a, Z) -> (a - Z, a + Z), which is 2. The N(0, 1) log-density is
      # -0.5 * log(2 * pi) - Z^2 / 2. Its constant must be -0.5 * log(2 * pi)
      # so it cancels the extra N(0, 1) prior term that logg() adds for the new
      # coordinate.
      logjump <- -dnorm(Z, log = TRUE) + log(2)
    } else if (newk >= 1) {
      # Death: merge avect[newk] and avect[k] into their midpoint. This is the
      # exact inverse of a birth from newk, with Z = (avect[k] - avect[newk]) / 2.
      Z <- 0.5 * (avect[k] - avect[newk])
      newavect[newk] <- 0.5 * (avect[newk] + avect[k])
      newavect[k] <- 0  # keep avect[i] = 0 for i > k
      logjump <- dnorm(Z, log = TRUE) - log(2)
    } else {
      # k == 1 has no component to remove. logg(0, .) is -Inf, so this
      # proposal is rejected below.
      logjump <- 0
    }
    # A birth beyond kmax also yields newlog = -Inf and is rejected.
    newlog <- logg(newk, newavect)
    logA <- newlog - latestlog + logjump
    logU <- log(runif(1))  # for accept/reject
    if ((newlog > -Inf) && (logU < logA)) {
      k <- newk
      avect <- newavect
      latestlog <- newlog
    }
  }

  klist[itnum] <- k
  a1list[itnum] <- avect[1]
  a2list[itnum] <- avect[2]
  a3list[itnum] <- avect[3]
}

# Summary of the run. Numbers are formatted by cat() itself; the spaces are
# spelled out explicitly here.
cat("Ran transdimensional algorithm for ", M, " iterations\n", sep = "")
cat("within-dimension acceptance rate = ", numaccept / M, " \n", sep = "")
cat("mean k = ", mean(klist), " ; final k = ", k, " \n", sep = "")
cat("final avect: ", avect[seq_len(k)], "\n")
cat("final log-density: ", latestlog, "\n")

# Density of the equal-weight mixture (1/k) * sum_i N(avect[i], 1) at x.
# x may be a vector. The globals k and avect are read at call time, so this
# reflects whatever state the chain is in when it is called.
finaldens <- function(x) {
  per_component <- lapply(seq_len(k), function(i) dnorm(x, avect[i], 1))
  Reduce(`+`, per_component) / k
}

# Plot the mixture density implied by the chain's final state over [-10, 25],
# with the observations marked along the x-axis at height 0.
plot(finaldens, -10, 25)
points(y, numeric(length(y)))
# Trace plots (uncomment to inspect mixing):
# plot(klist, type = "l")
# plot(a1list, type = "l")

