/****************************************************************************
 * Histogram -- an array of clusters, used for probability models           *
 ****************************************************************************
 * Every cluster forms a bin in the histogram.                              *
 *                                                                          *
 ****************************************************************************/

#include "histogram.h"
#include "gausimage.h"
#include "cluster.h"

// Each time a new cluster/pattern is added, assert that it has the same
//   dimensions as those already held, so only the first one needs checking
//   (better yet, save that info once, then check every addition against it)

// Currently a mess, because the code is still adapting to the change from
//   CArray to CList (but the change is appropriate)

// Threshold used by EliminateBins: a bin is removed when the ratio
//   count / (count + CountNears) falls below this value
#define COUNT_OT        .8

// Useful Macros
// Square of the argument; the argument text is expanded twice, so it is
//   evaluated twice
#define sqr(a)   ((a) * (a))

// Guarded division b / a:
//   a nonzero           -> b / a
//   a zero, b nonzero   -> b / 1e-5
//   both zero           -> 1.0
// 0 / 0 = 1
#define safed(a, b) ((a) ? ((b) / (a)) : ((b) ? ((b) / 1e-5) : 1.0))
// Zero guards: each yields its argument unless that argument is exactly zero,
//   in which case a small constant is substituted (1e-10 for a, 1e-5 for b)
// NOTE(review): these are function-like macros with one-letter names, so any
//   later `a(...)` or `b(...)` in this translation unit is rewritten by the
//   preprocessor; the argument is also evaluated twice
// a / b ~= 0, b / a ~= oo
#define a(a)     ((a) ? (a) : 1e-10)
#define b(a)     ((a) ? (a) : 1e-5)

// Creates an empty histogram: no bins, dimensions not yet fixed (0 means
//   "take them from the first cluster added"), and no patterns counted
Histogram::Histogram() {
  dims = 0;
  patterncount = 0;  // AddCluster accumulates into this, so it must start at 0
}

// Creates a histogram with one bin per pattern produced from the image
// The patterns are handed over to the new clusters (del == true)
Histogram::Histogram(Image &img) {
  GausImage base(img);
  dims = 0;
  patterncount = 0;  // AddCluster accumulates into this, so it must start at 0

  // Fill *this* histogram.  Writing `Histogram(...)` as a statement here only
  //   builds and immediately destroys an unnamed temporary, leaving this
  //   object empty.
  // NOTE(review): the array object returned by MakePatterns() is not released
  //   here; confirm who owns it
  AddClusters(*base.MakePatterns(), true);
}

// Creates a histogram from a cluster file; see LoadClusters
// The error code from LoadClusters cannot be reported from a constructor, so
//   a failed load simply leaves the histogram with whatever was read
Histogram::Histogram(const char *filename) {
  dims = 0;
  patterncount = 0;  // AddCluster accumulates into this, so it must start at 0
  LoadClusters(filename);
}

// Copy constructor: appends a copy of every cluster of the source, in the
//   same bin order
Histogram::Histogram(Histogram &copy) {
  dims = 0;
  patterncount = 0;  // AddCluster accumulates into this, so it must start at 0
  for (HistBin m = copy.GetFirstBin(); m; copy.IncrementBin(m))
    AddCluster(copy.GetCluster(m));
}

// Creates a histogram with one bin per pattern in the array; `del` is passed
//   straight through to AddClusters
Histogram::Histogram(CArray<Pattern *, Pattern *> &pats, bool del) {
  dims = 0;
  patterncount = 0;  // AddCluster accumulates into this, so it must start at 0
  AddClusters(pats, del);
}

// Destructor: removes (and thereby deletes) every bin still held
Histogram::~Histogram() {
  this->RemoveAllBins();
}

unsigned Histogram::GetDimensions() const {
  return dims;
}

// Number of bins (clusters) currently in the histogram
unsigned long Histogram::GetCount() const {
  const unsigned long bincount = bins.GetCount();
  return bincount;
}

unsigned long Histogram::GetPatternCount() const {
  return patterncount;
}

// Cluster stored at bin position m (m must be a valid position in this list)
Cluster &Histogram::GetCluster(HistBin m) const {
  Cluster *stored = bins.GetAt(m);
  return *stored;
}

// Overwrites the cluster at bin m with a copy of newcls (by assignment) and
//   keeps the running pattern count in step
void Histogram::SetCluster(HistBin m, Cluster &newcls) {
  Cluster *slot = bins.GetAt(m);

  // Swap the old contribution for the new one
  patterncount -= slot->GetCount();
  patterncount += newcls.GetCount();

  *slot = newcls;
}

// Position of the first bin; tests false when the histogram is empty
HistBin Histogram::GetFirstBin() const {
  HistBin head = bins.GetHeadPosition();
  return head;
}

// Position following m, leaving the caller's own position untouched
HistBin Histogram::GetNextBin(HistBin m) const {
  HistBin following = m;     // advance a private copy
  bins.GetNext(following);
  return following;
}

// Advances m in place to the following bin; the element GetNext hands back
//   is not needed
void Histogram::IncrementBin(HistBin &m) const {
  (void) bins.GetNext(m);
}

// Deletes the cluster at bin m and unlinks the bin
// Returns the position before the removed bin; if the removed bin was the
//   head, returns the new first bin instead (which tests false once the list
//   is empty)
HistBin Histogram::RemoveBin(HistBin m) {
  // Locate the predecessor while m is still linked in
  HistBin before = m;
  bins.GetPrev(before);

  Cluster *doomed = bins.GetAt(m);
  patterncount -= doomed->GetCount();
  delete doomed;
  bins.RemoveAt(m);

  if (before)
    return before;
  return GetFirstBin();
}

void Histogram::RemoveAllBins() {
  HistBin m = bins.GetTailPosition();

  while (m)
    m = RemoveBin(m);

  patterncount = 0;
}

// Appends a *copy* of newcls as a new bin
// The first cluster added fixes the histogram's dimensions; later ones are
//   asserted to match
// Returns the copy now stored in the histogram (not the caller's object), in
//   line with the pointer overloads, which return the cluster held by the bin
Cluster &Histogram::AddCluster(Cluster &newcls) {
  if (!dims)
    dims = newcls.GetDimensions();
  else
    assert(dims == newcls.GetDimensions());

  Cluster *stored = new Cluster(newcls);
  bins.AddTail(stored);
  patterncount += newcls.GetCount();

  return *stored;
}

// Builds a new cluster from the pattern reference and appends it as a bin
// Returns the cluster now held by the histogram
Cluster &Histogram::AddCluster(Pattern &newpat) {
  Cluster *wrapped = new Cluster(newpat);
  return AddCluster(wrapped);
}

// Appends the given cluster object itself (no copy) as a new bin; the
//   histogram deletes it when the bin is removed
// Dimensions are taken from the first cluster and asserted for the rest
Cluster &Histogram::AddCluster(Cluster *newcls) {
  if (dims)
    assert(dims == newcls->GetDimensions());
  else
    dims = newcls->GetDimensions();

  patterncount += newcls->GetCount();
  bins.AddTail(newcls);

  return *newcls;
}

// Builds a new cluster from the pattern pointer and appends it as a bin
// Returns the cluster now held by the histogram
Cluster &Histogram::AddCluster(Pattern *newpat) {
  return AddCluster(new Cluster(newpat));
}

// Appends the clusters stored in a file; the status code from LoadClusters
//   is discarded
void Histogram::AddClusters(const char *filename) {
  (void) LoadClusters(filename);
}

// Appends one bin per pattern in the array
// del selects the overload used: true passes each pattern by pointer, false
//   passes it by reference
void Histogram::AddClusters(CArray<Pattern *, Pattern *> &pats, bool del) {
  for (unsigned long idx = 0; idx < pats.GetSize(); ++idx) {
    Pattern *pat = pats[idx];

    if (del)
      AddCluster(pat);
    else
      AddCluster(*pat);
  }
}

// Reads a cluster file and appends every cluster in it as a new bin
// Layout read: optional '#' comment lines, a file-type line, two version
//   bytes, pattern/cluster type information, a bin count, then the bins
// 0 for success, otherwise negative:
//   CORRUPT_FILE_ERR  file could not be opened, or ended / was malformed
//   BAD_VERSION_ERR   file type or version does not match this program
// The file is closed on every path
int Histogram::LoadClusters(const char *filename) {
  FILE *fp = fopen(filename, "rb");
  if (!fp) {
    printf("LoadClusters could not open file (%s)\n", filename);
    return CORRUPT_FILE_ERR;
  }

  // Read Header: skip '#' comment lines, stop at the first other line
  // fgets must be checked on every read; at end of file it leaves the buffer
  //   untouched, which would otherwise repeat the previous '#' line forever
  char filetype[128];
  do {
    if (!fgets(filetype, sizeof(filetype), fp)) {
      printf("LoadClusters incountered corrupted file (at file type)\n");
      fclose(fp);
      return CORRUPT_FILE_ERR;
    }
  } while (*filetype == '#');

  // Drop newline (only if one is actually there)
  size_t len = strlen(filetype);
  if (len && filetype[len - 1] == '\n')
    filetype[len - 1] = '\0';

  // version -- kept as int so that EOF can be told apart from a real byte
  int ver1 = fgetc(fp);
  int ver2 = fgetc(fp);
  if (ver1 == EOF || ver2 == EOF) {
    printf("LoadClusters encountered corrupted file (at version)\n");
    fclose(fp);
    return CORRUPT_FILE_ERR;
  }

  // Check file type
  if (strcmp(filetype, CLUSTER_PROGRAM) || ver1 != CLUSTER_PRGVER1 ||
      ver2 != CLUSTER_PRGVER2) {
    if (strcmp(filetype, CLUSTER_PROGRAM))
      printf("LoadClusters incountered unexpected file type (%s)\n", filetype);
    else if (ver1 > CLUSTER_PRGVER1 ||
             (ver1 == CLUSTER_PRGVER1 && ver2 >= CLUSTER_PRGVER2))
      printf("LoadClusters can't read later version (%d.%d); Upgrade.\n",
             ver1, ver2);
    else
      printf("LoadClusters can't read outdated file (%d.%d); Please update.\n",
             ver1, ver2);
    fclose(fp);
    return BAD_VERSION_ERR;
  }

  // Type information; newcls is then used to read each bin
  // NOTE(review): newpat and newcls are not released here because their
  //   ownership is not visible from this file -- confirm whether they leak
  Pattern *newpat = Pattern::SelectPatternType(fp);
  Cluster *newcls = Cluster::SelectClusterType(fp, newpat);
  if (!newcls) {
    printf("LoadClusters encountered corrupted file (at cluster type)\n");
    fclose(fp);
    return CORRUPT_FILE_ERR;
  }

  // NOTE(review): the count is stored as a raw unsigned long, so files are
  //   only portable between builds where that type has the same size
  unsigned long count = 0;
  if (fread(&count, sizeof(count), 1, fp) != 1) {
    printf("LoadClusters encountered corrupted file (at bin count)\n");
    fclose(fp);
    return CORRUPT_FILE_ERR;
  }

  // Read in bins
  for (unsigned long m = 0; m < count; m++)
    AddCluster(newcls->ReadCluster(fp));

  fclose(fp);
  return 0;
}

// Writes out cluster data to a file
// Want the file type to be reusable
// Perhaps later add user defined structures (and recursive defs)
// Currently reasonable general but not for everything
int Histogram::SaveClusters(const char *filename) const {
  FILE *fp = fopen(filename, "wb");

  // Write Header
  // File type and version
  fprintf(fp, "#\n%s\n", CLUSTER_PROGRAM);
  fputc(CLUSTER_PRGVER1, fp);
  fputc(CLUSTER_PRGVER2, fp);

  // Global cluster info
  bins.GetAt(GetFirstBin())->WritePatternType(fp);
  bins.GetAt(GetFirstBin())->WriteClusterType(fp);
  bins.GetAt(GetFirstBin())->WritePatternGlobals(fp);
  bins.GetAt(GetFirstBin())->WriteClusterGlobals(fp);

  unsigned long count = GetCount();
  fwrite(&count, sizeof(unsigned long), 1, fp);

  // Output Clusters
  for (HistBin m = GetFirstBin(); m; IncrementBin(m))
    GetCluster(m).WriteCluster(fp);

  return 0;
}

void Histogram::PrintHistogram() const {
  for (HistBin m = GetFirstBin(); m; IncrementBin(m))
    GetCluster(m).PrintCluster();
}

// Only calculate once per histogram; best done before clustering
// Sums over the responses of every cluster
// Count is the number of patterns (rows * cols)
// Returns a newly allocated cluster built from the overall mean, the overall
//   variance and the total count; the caller receives the pointer
// Requires at least one bin (the head bin supplies the pattern type)
// NOTE(review): the accumulators rely on DuplicateType() giving a pattern
//   that starts at zero -- confirm; a total of 0 divides by zero below
Cluster *Histogram::NormalCluster() const {
  Pattern *mean = bins.GetHead()->GetMean().DuplicateType();
  Pattern *variance = bins.GetHead()->GetVariance().DuplicateType();
  unsigned long total = 0;  // accumulated below, so it must start at zero

  // First pass: total count and summed means
  for (HistBin m = GetFirstBin(); m; IncrementBin(m)) {
    total += GetCluster(m).GetCount();
    *mean += GetCluster(m).GetMean();
  }

  *mean /= total;

  // Second pass: summed squared deviations of each cluster mean
  // The difference and its square are separate temporaries, each freed here
  Pattern *temp1, *temp2;
  for (HistBin m = GetFirstBin(); m; IncrementBin(m)) {
    *variance += *(temp2 = (temp1 = GetCluster(m).GetMean() - *mean)->Sqr());
    delete temp1;
    delete temp2;
  }
  *variance /= total;

  return new Cluster(mean, variance, total);
}

// Note there is some polarization to creating clusters, as well as some
//   unfair (for testing) changing of the data over several passes
void Histogram::SelfCluster() {
  unsigned long remains = GetCount(), desired = 2, prevrem = 0;

  // Elements of clusters array for deleted clusters are set to null
  while (remains > desired && remains != prevrem) {
    prevrem = remains;
    // Compare each cluster to each other; ignore deleted clusters
    for (HistBin m = GetFirstBin(); m; IncrementBin(m)) {
      for (HistBin n = GetNextBin(m); n; IncrementBin(n))
	if (GetCluster(m).GetMean().NearByResponse(GetCluster(n).GetMean())) {
	  // A Success! Combine clusters...
	  GetCluster(m) += GetCluster(n);
	  GetCluster(n).EmptyPatternSet();
	  n = RemoveBin(n);
	  remains--;
	}
    }
  }
}

// Builds a histogram describing the passed patterns with respect to my
//   clusters: the result has the same bins as this histogram, but each bin
//   is emptied and its count replaced by the number of passed patterns near
//   the corresponding model cluster (Cluster::CountNears)
// This is not a nearest neighbor partition: a pattern may be counted by
//   several model clusters
// This histogram (the model) is left untouched; the caller receives the
//   newly allocated result
Histogram *Histogram::TestHistogram(const CArray<Pattern *, Pattern *> &pats) {
  Histogram *newhist = new Histogram(*this);

  // Each list has to be walked with its own position: a position taken from
  //   this histogram's list does not identify a bin of the copy.  The copy
  //   keeps the bin order, so the two walks stay in step.
  HistBin n = newhist->GetFirstBin();
  for (HistBin m = GetFirstBin(); m && n;
       IncrementBin(m), newhist->IncrementBin(n)) {
    Cluster &result = newhist->GetCluster(n);

    // Keep the copy's running pattern count in step with the new bin count
    newhist->patterncount -= result.GetCount();
    result.Empty();
    result.SetCount(GetCluster(m).CountNears(pats));
    newhist->patterncount += result.GetCount();
  }

  return newhist;
}

// This partitioning algorithm is described in "Cluster-Based Probability Model
//   and Its Application to Image and Texture Processing" (Popat and Picard)
// Replaces the histogram's contents with a clustering of pats:
//   1. discard all existing bins and start from one cluster built from pats
//   2. duplicate bins, shifting each duplicate's mean by a random amount
//   3. assign every pattern to the bin whose mean is nearest (Cartesian
//      distance) and recompute the total distortion D; repeat while D changes
//      enough
//   4. go back to 2 until the outer loop condition stops
// pats must not be empty (pats[0] is read for the dimensions)
// NOTE(review): as written the outer loop appears never to run, and several
//   tests look inverted relative to the comments beside them; see the notes
//   below before relying on this routine
void Histogram::LBGCluster(CArray<Pattern *, Pattern *> &pats) {
  float D, prevD;  // total squared distortion: this pass / previous pass

  Cluster *prefcls; // For nearest neighbor
  double prefdist, testdist;

  // cells: number of bins; largecells: bins that reached largethrhld patterns
  unsigned long cells = 1, largecells = 1;
  unsigned long prevcells;

  // Arbitrary chosen variables
  unsigned long desired = 100, largethrhld = 10;
  float clchange = .01, Dchange = .01;

  // Clean histogram
  RemoveAllBins();
  dims = pats[0]->GetDimensions();

  // Start out with a single, average partition
  AddCluster(new Cluster(pats));

  // Initial value of D, total squared distortion (from total variance)
  // NOTE(review): GetCount() is the number of bins (1 at this point), not
  //   the number of patterns -- confirm which was meant
  D = bins.GetHead()->GetVariance().SumReduce() * GetCount();

  // Loop until desired clustering
  // NOTE(review): on entry cells == 1, desired == 100 and largecells == 1,
  //   so this condition is false and the body never executes; it looks like
  //   `cells < desired && largecells` was intended -- confirm
  while (cells >= desired || !largecells) {

    // Split every (large) histogram bin in two
    // Only the bins present before this pass are visited; duplicates are
    //   appended at the tail and are not themselves split in this pass
    prevcells = cells;
    cells = 0;
    HistBin m = GetFirstBin();
    for (unsigned int n = 0; n < prevcells; n++, IncrementBin(m)) {
      Cluster &currcls = GetCluster(m);
      cells++;

      // NOTE(review): the comment above says large bins are split, but this
      //   test selects bins with FEWER than largethrhld patterns -- confirm
      if (currcls.GetCount() < largethrhld) {
	// First, duplicate bin
	// NOTE(review): this relies on AddCluster(Cluster &) returning the
	//   copy stored in the new bin rather than its argument -- verify
	Cluster &newcls = AddCluster(GetCluster(m));
	newcls.Empty();
	cells++;

	// Shift cluster a random amount in each dimension
	newcls.GetMean().RandomShift(clchange);
      }

      // Reset histogram bin sets
      currcls.Empty();
    }

    // Now partition patterns
    do {
      prevD = D;

      // Nearest neighbor partition of clusters
      largecells = 0;

      for (unsigned long n = 0; n < pats.GetSize(); n++) {
	prefcls = NULL;
	// largeval is defined outside this file; presumably a distance larger
	//   than any real one -- TODO confirm
	prefdist = largeval;

	// Find the bin whose mean is nearest to this pattern
	for (m = GetFirstBin(); m; IncrementBin(m)) {
	  if ((testdist = pats[n]->CartesianDistance(GetCluster(m).GetMean()))
	      < prefdist) {
	    prefcls = &(GetCluster(m));
	    prefdist = testdist;
	  }
	}

	// Add pattern to the set associated with the preferred cluster
	// A bin is counted as large at the moment its count reaches the
	//   threshold exactly
	// NOTE(review): prefcls stays NULL if no distance is below largeval
	prefcls->AddPatternShift(pats[n]);
	if (prefcls->GetCount() == largethrhld)
	  largecells++;
      }

      // Recalculate mean based on new sets
      // Also, calculate D, the total squared distortion, by each set
      D = 0.0;
      for (m = GetFirstBin(); m; IncrementBin(m))
	D += GetCluster(m).GetVariance().SumReduce();
      // NOTE(review): this repeats only while D fell below Dchange (1%) of
      //   the previous value, and the bin sets are not emptied between
      //   repetitions -- confirm both are intended
    } while (D < prevD * Dchange);
  }
}

// TODO: Place elsewhere (with new ClusterTree class?)
// Then, here, split tree up for clusters
// Replaces the histogram's contents with clusters found on the pixel grid:
//   every located pattern becomes a grid node, every pair of 8-neighbours is
//   joined by an edge, edges are sorted by weight, those lighter than dist
//   are marked, and each resulting tree becomes one cluster
// An empty pattern array leaves the histogram empty
// The single-letter puts() calls are progress traces and are kept as they were
// NOTE(review): grid cells for which pats holds no pattern keep a
//   default-constructed node -- confirm the edge weights cope with that
void Histogram::GraphCluster(CArray<LocPattern *, LocPattern *> &pats,
                             double dist) {
  RemoveAllBins();
  puts("A");

  // Find grid dimensions from the largest row / column seen
  unsigned rowcnt = 0;
  unsigned colcnt = 0;
  for (unsigned long m = 0; m < pats.GetSize(); m++) {
    rowcnt = max(pats[m]->GetRow() + 1, rowcnt);
    colcnt = max(pats[m]->GetColumn() + 1, colcnt);
  }
  puts("B");

  // Nothing to cluster; the unsigned arithmetic below also needs both >= 1
  if (!rowcnt || !colcnt)
    return;

  // Generate Minimum Spanning Tree
  // Calculate total number of edges
  // 8 surrounding each pixel, once, except down-right borders, except corners
  unsigned long edgecnt = 4UL * rowcnt * colcnt
    - 3UL * rowcnt - 3UL * colcnt + 2;

  // Place clusters into grid for easy access
  // Heap storage, row-major: the grid grows with the image area, which is too
  //   much for the stack, and runtime-sized local arrays are not standard C++
  ClusterNode *grid = new ClusterNode[(unsigned long) rowcnt * colcnt];
  for (unsigned long m = 0; m < pats.GetSize(); m++)
    grid[(unsigned long) pats[m]->GetRow() * colcnt
         + pats[m]->GetColumn()].SetPattern(*pats[m]);
  puts("C");

  // Generate all edges
  ClusterEdge **edges = new ClusterEdge *[edgecnt];
  unsigned long cedge = 0;
  for (unsigned long r = 0; r < rowcnt; r++) {
    ClusterNode *row = grid + r * colcnt;
    ClusterNode *below = row + colcnt;  // only dereferenced when hasdown
    const bool hasdown = r + 1 < rowcnt;

    for (unsigned long c = 0; c < colcnt; c++) {
      const bool hasright = c + 1 < colcnt;

      if (hasright)                                       // Edge to Right
        edges[cedge++] = new ClusterEdge(row[c], row[c + 1]);
      if (hasdown && hasright)                            // Edge to Down-Right
        edges[cedge++] = new ClusterEdge(row[c], below[c + 1]);
      if (hasdown)                                        // Edge to Down
        edges[cedge++] = new ClusterEdge(row[c], below[c]);
      if (hasdown && c > 0)                               // Edge to Down-Left
        edges[cedge++] = new ClusterEdge(row[c], below[c - 1]);
    }
  }
  assert(cedge == edgecnt);
  puts("D");

  // Sort edges by weight
  qsort(edges, edgecnt, sizeof(ClusterEdge *), EdgeCompare2);
  puts("Q");

  // Make the graph
  for (cedge = 0; cedge < edgecnt && edges[cedge]->GetWeight() < dist; cedge++)
    edges[cedge]->Mark();
  puts("E");

  // Inner scope so that usedflag is destroyed before the grid is released,
  //   as it was when the grid was a local array
  {
    // Use trees from grid to generate new clusters
    ClusterTree usedflag;
    for (cedge = 0; cedge < edgecnt && edges[cedge]->GetWeight() < dist;
         cedge++) {

      // Check if this edge is already in a used tree
      ClusterTree &toptree = edges[cedge]->ScaleToTop();
      if (toptree.IsInTree(usedflag))
        continue;

      // Add new cluster
      Cluster *newcls = toptree.MakeCluster();
      newcls->SetToDeleteAll();  // Patterns are now its
      AddCluster(newcls);

      // Mark edge as part of a used tree
      toptree.CombineWith(usedflag);
    }
    puts("F");

    // Free generation structures
    // usedflag itself is a local object and must never be passed to delete,
    //   which would happen if it were still its own top
    // NOTE(review): assumes deleting the top releases the marked edges and
    //   does not try to delete usedflag or the grid nodes -- confirm
    if (&(usedflag.ScaleToTop()) != &usedflag)
      delete &(usedflag.ScaleToTop());

    // cedge now indexes the first edge the loop above stopped at; it may
    //   equal edgecnt, in which case there is nothing left to delete
    for (; cedge < edgecnt; cedge++)
      delete edges[cedge];
    puts("G");
  }

  delete[] edges;
  delete[] grid;
}
  

// Eliminates matching clusters that are high in out of class patterns
// Keeps track of next position before needed so that removed can continue
//   as remove bins
void Histogram::EliminateBins(const CArray<Pattern *, Pattern *> &pats) {
  printf("Here: %ld to %ld\n", pats.GetSize(), GetCount());
  for (HistBin m = GetFirstBin(); m; IncrementBin(m))
    if (GetCluster(m).GetCount() /
	(GetCluster(m).GetCount() + GetCluster(m).CountNears(pats)) < COUNT_OT)
      m = RemoveBin(m);
  printf("Done: %ld to %ld\n", pats.GetSize(), GetCount());
}

// Returns the number of clusters remaining
unsigned long Histogram::RemoveSingles() {
  unsigned long remains = 0;

  for (HistBin m = GetFirstBin(); m; IncrementBin(m))
    if (GetCluster(m).GetCount() == 1)
      m = RemoveBin(m);
    else
      remains++;

  return remains;
}

// Compares pattern to each of the clusters, determining that one pattern's
//   probability of being generated in the histogram of all the clusters
// The result is a mixture: each cluster contributes its weight (its count
//   over the total count) times a product of one-dimensional Gaussians built
//   from the cluster's mean and variance
// The pattern must have the histogram's dimensions (asserted)
// Prints the dimensions and the result as a trace line
// NOTE(review): if every cluster count is zero the weights are 0 / 0
double Histogram::ProbByPattern(Pattern &pattern) const {
  double prob = 0.0, gaus;
  unsigned long total = 0;  // accumulated below, so it must start at zero
  Pattern *currmean;
  Pattern *currvar;

  assert(pattern.GetDimensions() == dims);

  for (HistBin m = GetFirstBin(); m; IncrementBin(m))
    total += GetCluster(m).GetCount();

  // Sum m=1 to count (Wm * Product).  Wm = Popm / N.  Product below.
  for (HistBin m = GetFirstBin(); m; IncrementBin(m)) {
    currmean = &(GetCluster(m).GetMean());
    currvar = &(GetCluster(m).GetVariance());
    gaus = 1.0;

    // Product i=1 to d (Kmi(v[i])).  Kmi(v[i]) defined below.
    // The a() / b() macros replace exact zeros by small constants so that
    //   neither the exponent nor the normalizer divides by zero
    for (unsigned i = 0; i < pattern.GetDimensions(); i++)
      gaus *= a(exp(-b(sqr(pattern[i] - currmean->GetResponse(i)))
                    / a(2.0 * currvar->GetResponse(i))))
        / b(sqrt(2.0 * M_PI * currvar->GetResponse(i)));

    prob += ((double) GetCluster(m).GetCount() / (double) total) * gaus;
  }

  printf(" %d, %g\n", dims, prob);

  return prob;
}

// Calculate Cross entropy:
//   The Test Histogram is the "guessed" abundance of each of the clusters
//   However, we know the real abundance of them
//   Calculate the difference: p is guessed, q is true, w = bin-pop / total-pop
//   D(p||q) = Int(p(x) log(p(x)/q(x))) ~= w pi(x) log(pi(x)/qi(x) per cluster
// Both histograms must have the same number of bins (asserted); bins are
//   paired in list order
// Returns the sum of the absolute per-bin terms divided by the bin count
double Histogram::EntropyByClass(Histogram &test) const {
  double prob = 0.0;
  double testprob;
  unsigned long bintotal = 0, tsttotal = 0;
  double dbintotal, dtsttotal;
  HistBin m, n;  // m walks this histogram, n walks test; never interchange

  assert(GetCount() == test.GetCount());

  // Count total in test clusters and total in my clusters
  for (m = GetFirstBin(), n = test.GetFirstBin(); m && n;
       IncrementBin(m), test.IncrementBin(n)) {
    bintotal += GetCluster(m).GetCount();
    tsttotal += test.GetCluster(n).GetCount();
  }

  dbintotal = (double) bintotal;
  dtsttotal = (double) tsttotal;

  /* Calculate Sum */
  for (m = GetFirstBin(), n = test.GetFirstBin(); m && n;
       IncrementBin(m), test.IncrementBin(n)) {
    // The test bin has to be reached through test's own position, n
    testprob = (double) test.GetCluster(n).GetCount() / dtsttotal;

    // p log p is taken as 0 for p == 0; evaluating it would give 0 * -inf
    if (testprob != 0.0)
      prob += fabs(testprob * log(testprob * dbintotal
                                  / (double) GetCluster(m).GetCount()));
  }

  return prob / (double) GetCount();
}

// Calculates the chi-square test for differences between the histograms for
//   the model image (in the clusters) and the test image
// Chi^2 = Sum_i((ModelBin_i - TestBin_i)^2 / ModelBin_i)
// Returns the raw statistic; no conversion to a probability is done here
// Both histograms must have the same number of bins (asserted); bins are
//   paired in list order
double Histogram::ChiSqrByClass(Histogram &test) const {
  double chisqr = 0.0;
  HistBin m, n;  // m walks this histogram, n walks test; never interchange

  assert(GetCount() == test.GetCount());

  for (m = GetFirstBin(), n = test.GetFirstBin(); m && n;
       IncrementBin(m), test.IncrementBin(n)) {
    // Convert before subtracting so that a test count larger than the model
    //   count cannot wrap around if the counts are unsigned
    double expected = (double) GetCluster(m).GetCount();
    double diff = expected - (double) test.GetCluster(n).GetCount();

    // An exact match contributes nothing; skipping it also keeps an empty
    //   model bin matched by an empty test bin from producing 0 / 0
    if (diff != 0.0)
      chisqr += sqr(diff) / expected;
  }

  return chisqr;
}

// Calculates the chi-square test for differences between the histograms for
//   the model image (in the clusters) and the test image
// Uses given weighting factors to correctly use multiple images
// Chi^2 = Sum_i((Sqrt(Sum_k(ModelImgPixCnt_k) / TestImgPixCnt) ModelBin_i -
//                Sqrt(TestImgPixCnt / Sum_k(ModelImgPixCnt_k)) TestBin_i)^2
//               / ModelBin_i)
// In the code, scale multiplies the model bin and divides the test bin
// Both histograms must have the same number of bins (asserted); bins are
//   paired in list order
double Histogram::MultipleChiSqrByClass(Histogram &test, double scale) const {
  double chisqr = 0.0;
  HistBin m, n;  // m walks this histogram, n walks test; never interchange

  assert(GetCount() == test.GetCount());

  for (m = GetFirstBin(), n = test.GetFirstBin(); m && n;
       IncrementBin(m), test.IncrementBin(n)) {
    double expected = (double) GetCluster(m).GetCount();
    double diff = expected * scale
      - (double) test.GetCluster(n).GetCount() / scale;

    // An exact match contributes nothing; skipping it also keeps an empty
    //   model bin matched by an empty test bin from producing 0 / 0
    if (diff != 0.0)
      chisqr += sqr(diff) / expected;
  }

  return chisqr;
}

// Generates a factor to divide the results of ProbClusterByClass so that a
//   perfect match of histogram will be probability 1.  Note that the image
//   that generated that histogram may still have probability < 1, because the
//   histogram does not accurately represent it
// Try dividing this value by count for added fun
double Histogram::NormalizeProb() const {
  double prob = 0.0;

  for (HistBin m = GetFirstBin(); m; IncrementBin(m))
    prob += ProbByPattern(GetCluster(m).GetMean()) * (double) GetCount();

  return prob;
}
