/****************************************************************************
 * Cluster -- a cluster of several patterns                                 *
 ****************************************************************************
 * A cluster (or histogram bin) is an aggregate pattern associated with a   *
 * set of component patterns.                                               *
 *                                                                          *
 ****************************************************************************/

#include "cluster.h"
#include "histfile.h"

#define GENERAL_CLUSTER_TYPE "General Cluster"

// Creates a one-member cluster from a pattern that the caller keeps owning.
// The mean is a duplicate of pattern. The variance is a zeroed pattern
// obtained from pattern.DuplicateType(). pattern itself is added to the set
// with its delete flag false, so the destructor will not free it.
Cluster::Cluster(Pattern &pattern) :
  mean(*pattern.Duplicate()), variance(*pattern.DuplicateType()) {

  variance.Zero();
  count = 0;  // AddPattern increments count, so it must start initialized
  AddPattern(pattern);
}

// Creates a one-member cluster and takes ownership of pattern. The mean is a
// duplicate of *pattern. The variance is a zeroed pattern obtained from
// pattern->DuplicateType(). pattern is added to the set with its delete flag
// true, so the destructor frees it.
Cluster::Cluster(Pattern *pattern) :
  mean(*pattern->Duplicate()), variance(*(pattern->DuplicateType())) {

  variance.Zero();
  count = 0;  // AddPattern increments count, so it must start initialized
  AddPattern(pattern);
}

// Builds a cluster directly from summary statistics. The mean and variance
// arguments are duplicated and not retained. The pattern set is left empty,
// and ncount is used as the population count as-is.
Cluster::Cluster(Pattern &nmean, const Pattern &nvariance,
		 unsigned long ncount) :
  mean(*nmean.Duplicate()), variance(*nvariance.Duplicate()) {
  this->count = ncount;
}

// Builds a cluster from summary statistics and adopts the given mean and
// variance objects (the destructor deletes them). The pattern set is left
// empty, and ncount is used as the population count as-is.
Cluster::Cluster(Pattern *nmean, Pattern *nvariance, unsigned long ncount) :
  mean(*nmean), variance(*nvariance) {
  this->count = ncount;
}

// Builds a cluster from an array of patterns. Each pattern is folded into the
// running mean and variance via AddPatternsShift. When del is true the
// cluster takes ownership of the patterns. pats must be non-empty, because
// pats[0] provides the objects the mean and variance are duplicated from.
Cluster::Cluster(CArray<Pattern *, Pattern *> &pats, bool del) :
  mean(*pats[0]->Duplicate()), variance(*pats[0]->Duplicate()) {
  // Start from an empty population. The shift formula multiplies the
  // current mean and variance by this zero count, so for finite data their
  // initial values do not affect the result.
  this->count = 0;
  AddPatternsShift(pats, del);
}

// Builds a cluster from a list of patterns. Each pattern is folded into the
// running mean and variance via AddPatternsShift. When del is true the
// cluster takes ownership of the patterns. pats must be non-empty, because
// its head provides the objects the mean and variance are duplicated from.
Cluster::Cluster(CList<Pattern *, Pattern *> &pats, bool del) :
  mean(*pats.GetHead()->Duplicate()), variance(*pats.GetHead()->Duplicate()) {
  // Start from an empty population. The shift formula multiplies the
  // current mean and variance by this zero count, so for finite data their
  // initial values do not affect the result.
  this->count = 0;
  AddPatternsShift(pats, del);
}

// Copy constructor. The mean and variance are duplicated. The pattern
// pointers are shared with copy through the one-argument AddPatterns call,
// so its delete-flag default (declared outside this file) decides ownership.
// The population count is taken from copy.
Cluster::Cluster(Cluster &copy) :
  mean(*(copy.mean.Duplicate())), variance(*(copy.variance.Duplicate())) {
  count = 0;
  AddPatterns(copy.pset);
  // This AddPatterns overload adds the list size to count. Assign afterwards
  // so the result equals copy.count instead of copy.count plus the set size.
  count = copy.count;
}

// Frees the mean and variance objects, which the cluster always owns. Also
// frees every pattern in the set whose delete flag is true. Borrowed
// patterns (flag false) are left alone.
Cluster::~Cluster() {
  delete &mean;
  delete &variance;

  POSITION patpos = pset.GetHeadPosition();
  POSITION delpos = pdel.GetHeadPosition();
  while (patpos && delpos) {
    Pattern *pat = pset.GetNext(patpos);
    if (pdel.GetNext(delpos))
      delete pat;
  }
}

// Replaces this cluster's statistics and pattern set with those of copy.
// Pattern pointers and their delete flags are shared, not duplicated, so one
// of the two clusters must later be emptied to avoid deleting a pattern
// twice. Patterns this cluster held beforehand are dropped via Empty()
// without being deleted.
const Cluster &Cluster::operator=(Cluster &copy) {
  // Self-assignment would otherwise wipe the set and reset count to 0 via
  // Empty() before copying "from" the now-empty self.
  if (this == &copy)
    return *this;

  Empty();

  mean = copy.mean;
  variance = copy.variance;
  count = copy.count;

  // This overload leaves count unchanged, so the value set above stands.
  AddPatterns(copy.pset, copy.pdel);

  return *this;
}

// Calculations for determining variance given previous mean and variances:
//   Given (x1 - m1)^2 + ... + (xn - m1)^2 and (y1 - m2)^2 + ... + (ym - m2)^2,
//   we want, for the combined mean m3,
//   (x1 - m3)^2 + ... + (xn - m3)^2 + (y1 - m3)^2 + ... + (ym - m3)^2
//   I have m1, m2, m3, and the previous variances.
//   x1^2 - 2 x1 m1 + m1^2 + ... + xn^2 - 2 xn m1 + m1^2
//   x1^2 - 2 x1 m3 + m3^2 + ... + xn^2 - 2 xn m3 + m3^2
//   I can remove n * m1^2 and m * m2^2 and add (n + m) * m3^2
//     I need to change Sum (- 2 xi m1)  to Sum (- 2 xi m3)
//     Sum xi m1 = m1 * Sum xi = m1 * n * m1 = n * m1^2
//     Sum xi m3 = m3 * Sum xi = m3 * n * m1
//   x1^2 - 2 x1 m1 + m1^2 + ... = Sum(xi^2) - 2 n m1^2 + n * m1^2 + ...
//     -= -2 n m1^2 + n * m1^2 + ... = -n * m1^2
//     += -2 n m1 m3 + n * m3^2 + ...
//   The code below treats each stored variance as the sum of squared
//   deviations divided by the count, and evaluates per dimension:
//     (n v1 + m v2 + n m1^2 + m m2^2 + (n + m) m3^2
//      - 2 n m1 m3 - 2 m m2 m3) / (n + m)
// Allocates memory for the new cluster; the caller owns the result.
// Requires that either result or both operands are emptied: the result
// shares both operands' pattern pointers and delete flags.
// NOTE(review): if both counts are 0, newcount is 0 and the divisions below
// yield NaN; presumably callers never add two empty clusters -- confirm.
Cluster *Cluster::operator+(Cluster &right) {
  unsigned dims = mean.GetDimensions();
  Pattern *newmean = mean.DuplicateType();
  Pattern *newvariance = variance.DuplicateType();
  unsigned long newcount;

  assert(dims == right.GetDimensions());

  newcount = count + right.count;
  for (unsigned i = 0; i < dims; i++) {
    // Count-weighted average of the two means.
    newmean->SetResponse(i, (((double) count * mean[i]
			      + (double) right.count * right.mean[i])
			     / (double) newcount));
    // Combined variance about the new mean (formula above).
    newvariance->SetResponse(i, ((double) count * variance[i]
				 + (double) right.count * right.variance[i]
				 + (double) count * sqr(mean[i])
				 + (double) right.count * sqr(right.mean[i])
				 + ((double) newcount
				    * sqr(newmean->GetResponse(i)))
				 - (double) (2.0 * count * mean[i]
					     * newmean->GetResponse(i))
				 - (double) (2.0 * right.count * right.mean[i]
					     * newmean->GetResponse(i)))
			     / (double) newcount);
    // A negative variance (e.g. from floating-point rounding) is replaced by
    // midval, which is defined outside this section of the file.
    if (newvariance->GetResponse(i) < 0.0)
      newvariance->SetResponse(i, midval);
  }

  // The new cluster adopts newmean/newvariance. right's patterns and flags
  // are appended first, then this cluster's. The count is set explicitly
  // because AddPatterns(list, flags) leaves it unchanged.
  Cluster *newcls = new Cluster(newmean, newvariance, 0);
  newcls->AddPatterns(right.pset, right.pdel);
  newcls->AddPatterns(pset, pdel);
  newcls->SetCount(newcount);

  return newcls;
}

// Requires that one is emptied
// Merges right into this cluster in place. Per dimension, with n = count,
// m = right.count, m1/m2 the old means and m3 the new mean:
//   m3 = (n m1 + m m2) / (n + m)
//   v  = (n v1 + m v2 + n m1^2 + m m2^2 + (n + m) m3^2
//         - 2 n m1 m3 - 2 m m2 m3) / (n + m)
// right's pattern pointers and delete flags are appended (shared, not
// duplicated), so one of the two clusters must later be emptied to avoid
// deleting a pattern twice.
// NOTE(review): if both counts are 0 the divisions yield NaN -- confirm
// callers never merge two empty clusters.
const Cluster &Cluster::operator+=(Cluster &right) {
  unsigned dims = mean.GetDimensions();
  Pattern *newmean = mean.DuplicateType();
  unsigned long newcount;

  assert(dims == right.GetDimensions());

  newcount = count + right.count;
  for (unsigned i = 0; i < dims; i++) {
    // Count-weighted average of the two means (kept aside: the variance
    // formula below still needs the old mean).
    newmean->SetResponse(i, (((double) count * mean[i]
			      + (double) right.count * right.mean[i])
			     / (double) newcount));
    // Variance is updated in place; only element i is read and written here.
    variance.SetResponse(i, ((double) count * variance[i]
			     + (double) right.count * right.variance[i]
			     + (double) count * sqr(mean[i])
			     + (double) right.count * sqr(right.mean[i])
			     + (double) newcount * sqr(newmean->GetResponse(i))
			     - (double) (2.0 * count * mean[i]
					 * newmean->GetResponse(i))
			     - (double) (2.0 * right.count * right.mean[i]
					 * newmean->GetResponse(i)))
			 / (double) newcount);
    // A negative variance (e.g. from floating-point rounding) is replaced by
    // midval, which is defined outside this section of the file.
    if (variance[i] < 0.0)
      variance.SetResponse(i, midval);
  }

  AddPatterns(right.pset, right.pdel);  // requires that one is bin-removed

  // AddPatterns(list, flags) does not touch count, so set it explicitly;
  // the mean is replaced only after the loop has used the old values.
  count = newcount;
  mean = *newmean;

  delete newmean;

  return *this;
}

// Returns the cluster's mean pattern (owned by the cluster).
Pattern &Cluster::GetMean() const {
  return this->mean;
}

// Returns the cluster's variance pattern (owned by the cluster).
Pattern &Cluster::GetVariance() const {
  return this->variance;
}

unsigned long Cluster::GetCount() const {
  return count;
}

unsigned Cluster::GetDimensions() const {
  return mean.GetDimensions();
}

unsigned long Cluster::GetPatternCount() const {
  return pset.GetCount();
}

void Cluster::SetCount(unsigned long cnt) {
  count = cnt;
}

// Resets the population count and forgets every stored pattern pointer and
// flag. No pattern is deleted, even if flagged for deletion. The mean and
// variance are left as they are.
void Cluster::Empty() {
  pset.RemoveAll();
  pdel.RemoveAll();
  count = 0;
}

// Forgets every stored pattern pointer and flag without deleting anything.
// The population count is kept.
void Cluster::EmptyPatternSet() {
  pdel.RemoveAll();
  pset.RemoveAll();
}

// Increments the population count; the pattern set is unchanged.
void Cluster::operator++() {
  ++this->count;
}

// Decrements the population count; the pattern set is unchanged. There is
// no underflow guard for an unsigned count of 0.
void Cluster::operator--() {
  --this->count;
}

void Cluster::SetToDeleteAll() {
  for (PattCell m = GetFirstCell(); m.valid; IncrementCell(m))
    pdel.SetAt(m.delpos, true);
}

void Cluster::SetToSaveAll() {
  for (PattCell m = GetFirstCell(); m.valid; IncrementCell(m))
    pdel.SetAt(m.delpos, false);
}

// Returns the pattern stored at cell m, or NULL when m is not a valid cell.
Pattern *Cluster::GetPattern(PattCell m) const {
  return m.valid ? pset.GetAt(m.patpos) : NULL;
}

// Replaces the pattern pointer stored at cell m. The matching delete flag
// and the old pointer's object are left untouched.
void Cluster::SetPattern(PattCell m, Pattern *newpat) {
  POSITION where = m.patpos;
  pset.SetAt(where, newpat);
}

// Returns a cell at the head of the pattern and flag lists. Its valid field
// is false when either list is empty.
PattCell Cluster::GetFirstCell() const {
  PattCell head;

  head.patpos = pset.GetHeadPosition();
  head.delpos = pdel.GetHeadPosition();
  head.valid = (head.patpos && head.delpos);

  return head;
}

// Returns the cell after m (m is passed by value and not modified for the
// caller). The result's valid field is false once either list is exhausted.
PattCell Cluster::GetNextCell(PattCell m) const {
  IncrementCell(m);
  return m;
}

// Advances m in place to the next cell. m.valid becomes false once either
// the pattern list or the flag list runs out.
void Cluster::IncrementCell(PattCell &m) const {
  pdel.GetNext(m.delpos);
  pset.GetNext(m.patpos);
  m.valid = m.patpos && m.delpos;
}

// Removes the pattern pointer and delete flag at cell m (the pattern itself
// is not deleted). Returns the cell that preceded m. If m was the first
// cell, returns the new first cell instead; that cell is invalid when the
// set is now empty.
// NOTE(review): in the first-cell case, a caller loop that increments after
// RemoveCell will skip the new head -- confirm callers account for this.
PattCell Cluster::RemoveCell(PattCell m) {
  PattCell prev = m;
  pset.GetPrev(prev.patpos);
  pdel.GetPrev(prev.delpos);  // each position must be walked on its own list
  prev.valid = (prev.patpos && prev.delpos);

  pset.RemoveAt(m.patpos);
  pdel.RemoveAt(m.delpos);

  if (!prev.valid)
    return GetFirstCell();
  return prev;
}

// Appends a borrowed pattern (delete flag false, so the cluster never frees
// it) and increments the population count. The mean and variance are not
// updated.
void Cluster::AddPattern(Pattern &newpat) {
  assert(newpat.GetDimensions() == GetDimensions());

  pset.AddTail(&newpat);
  pdel.AddTail(false);
  ++count;
}

// Appends a pattern the cluster takes ownership of (delete flag true, so the
// destructor frees it) and increments the population count. The mean and
// variance are not updated.
void Cluster::AddPattern(Pattern *newpat) {
  assert(newpat->GetDimensions() == GetDimensions());

  pset.AddTail(newpat);
  pdel.AddTail(true);
  ++count;
}

// Adds a borrowed pattern and shifts the running mean and variance to
// include it. Per dimension, with n = count before the add, x = newpat[i],
// m1 the old mean and m3 the new mean:
//   m3 = (n m1 + x) / (n + 1)
//   v  = (n v + n m1^2 + x^2 + (n + 1) m3^2 - 2 n m1 m3 - 2 x m3) / (n + 1)
// This is the single-pattern case of the two-cluster merge used by
// operator+ and operator+=.
void Cluster::AddPatternShift(Pattern &newpat) {
  unsigned dims = GetDimensions();
  Pattern *newmean = mean.DuplicateType();

  assert(dims == newpat.GetDimensions());

  for (unsigned i = 0; i < dims; i++) {
    // New mean is kept aside: the variance update still needs the old mean.
    newmean->SetResponse(i, ((double) count * mean[i] + newpat[i])
			 / ((double) count + 1.0));
    variance.SetResponse(i, ((double) count * variance[i]
			     + (double) count * sqr(mean[i]) + sqr(newpat[i])
			     + (((double) count + 1.0)
				* sqr(newmean->GetResponse(i)))
			     - (double) (2.0 * count * mean[i]
					 * newmean->GetResponse(i))
			     - (double) (2.0 * newpat[i]
					 * newmean->GetResponse(i)))
			 / ((double) count + 1.0));
    // A negative variance (e.g. from floating-point rounding) is replaced by
    // midval, which is defined outside this section of the file.
    if (variance[i] < 0.0)
      variance.SetResponse(i, midval);
  }

  // Appends the pattern with delete flag false and increments count.
  AddPattern(newpat);

  // Only now is the mean replaced, after the loop has finished reading it.
  mean = *newmean;

  delete newmean;
}

// Same as the reference overload, except that the cluster takes ownership of
// newpat: the delete flag of the entry just appended is switched to true.
void Cluster::AddPatternShift(Pattern *newpat) {
  AddPatternShift(*newpat);

  POSITION tail = pdel.GetTailPosition();
  pdel.SetAt(tail, true);
}

// Appends another set's pattern pointers together with their delete flags.
// The population count is left unchanged.
void Cluster::AddPatterns(CList<Pattern *, Pattern *> &pats,
			  CList<bool, bool> &dels) {
  pdel.AddTail(dels);
  pset.AddTail(pats);
}

// Appends every pattern pointer in pats, giving each the same delete flag
// del. The population count grows by the number of patterns in pats.
void Cluster::AddPatterns(CList<Pattern *, Pattern *> &pats, bool del) {
  pset.AddTail(pats);

  const unsigned long added = pats.GetCount();
  for (unsigned long k = 0; k < added; ++k)
    pdel.AddTail(del);

  count += added;
}

// Appends every pattern in pats through AddPattern, which also increments
// the count. When del is true the owning (pointer) overload is used;
// otherwise the borrowing (reference) overload is used.
void Cluster::AddPatterns(CArray<Pattern *, Pattern *> &pats, bool del) {
  if (del) {
    for (unsigned long idx = 0; idx < pats.GetSize(); ++idx)
      AddPattern(pats[idx]);
  } else {
    for (unsigned long idx = 0; idx < pats.GetSize(); ++idx)
      AddPattern(*pats[idx]);
  }
}

// Folds each pattern of the list, in order, into the running mean and
// variance via AddPatternShift. When del is true the cluster takes ownership
// of the patterns.
void Cluster::AddPatternsShift(CList<Pattern *, Pattern *> &pats, bool del) {
  for (POSITION pos = pats.GetHeadPosition(); pos; ) {
    Pattern *next = pats.GetNext(pos);
    if (del)
      AddPatternShift(next);
    else
      AddPatternShift(*next);
  }
}

// Folds each pattern of the array, in index order, into the running mean and
// variance via AddPatternShift. When del is true the cluster takes ownership
// of the patterns.
void Cluster::AddPatternsShift(CArray<Pattern *, Pattern *> &pats, bool del) {
  if (del) {
    for (unsigned long idx = 0; idx < pats.GetSize(); ++idx)
      AddPatternShift(pats[idx]);
  } else {
    for (unsigned long idx = 0; idx < pats.GetSize(); ++idx)
      AddPatternShift(*pats[idx]);
  }
}

// Counts how many patterns in pats the cluster mean reports as
// NearByResponse.
unsigned long Cluster::CountNears(const CArray<Pattern *, Pattern *> &pats)
  const {
  unsigned long hits = 0;

  for (unsigned long idx = 0; idx < pats.GetSize(); ++idx)
    hits += mean.NearByResponse(*pats[idx]) ? 1 : 0;

  return hits;
}

// Output single cluster to file
void Cluster::WriteCluster(FILE *fp) const {
  unsigned long m;

  fwrite(&count, sizeof(unsigned long), 1, fp);  // population count of cluster
  mean.WritePattern(fp);                         // cluster mean
  variance.WritePattern(fp);                     // cluster variance

  m = pset.GetCount();                           // count patterns in set
  fwrite(&m, sizeof(unsigned long), 1, fp);

  for (PattCell m = GetFirstCell(); m.valid; IncrementCell(m))
    GetPattern(m)->WritePattern(fp);       // patterns in set
}

// Writes the cluster type name on its own line. It is followed by one
// (role, type, 1) byte triple per field that WriteCluster emits, and a
// terminating 0 byte.
void Cluster::WriteClusterType(FILE *fp) const {
  static const int fields[][3] = {
    { HF_ROLE_COUNT, HF_TYPE_ULONG, 1 },  // population count
    { HF_ROLE_AVERG, HF_TYPE_PATRN, 1 },  // mean pattern
    { HF_ROLE_VARNC, HF_TYPE_PATRN, 1 },  // variance pattern
    { HF_ROLE_PNSET, HF_TYPE_ULPNA, 1 }   // unsigned long, pattern array
  };
  const size_t nfields = sizeof(fields) / sizeof(fields[0]);

  fprintf(fp, "%s\n", GENERAL_CLUSTER_TYPE);

  for (size_t f = 0; f < nfields; f++)
    for (int b = 0; b < 3; b++)
      fputc(fields[f][b], fp);

  fputc(0, fp);
}

// A general cluster has no global data to write, so nothing is emitted.
void Cluster::WriteClusterGlobals(FILE *fp) const {
  (void) fp;
}

// Prints the mean pattern (via Pattern::PrintPattern) followed by
// " x <count>" and a newline.
void Cluster::PrintCluster() const {
  mean.PrintPattern();
  printf(" x %lu\n", count);  // count is unsigned long, so %lu (not %ld)
}

// Output type and data specification
void Cluster::WritePatternType(FILE *fp) const {
  mean.WritePatternType(fp);
}

void Cluster::WritePatternGlobals(FILE *fp) const {
  mean.WritePatternGlobals(fp);
}

// Checks the (role, type, 1) triples that WriteClusterType emits after the
// type line. Then builds an empty cluster whose mean and variance are
// created from tmpl. Takes ownership of tmpl and deletes it.
Cluster *Cluster::ReadClusterType(FILE *fp, Pattern *tmpl) {
  static const int expected[] = {
    HF_ROLE_COUNT, HF_TYPE_ULONG, 1,
    HF_ROLE_AVERG, HF_TYPE_PATRN, 1,
    HF_ROLE_VARNC, HF_TYPE_PATRN, 1,
    HF_ROLE_PNSET, HF_TYPE_ULPNA, 1
  };

  // Read outside assert() so the bytes are consumed even when NDEBUG
  // compiles the checks away.
  for (size_t k = 0; k < sizeof(expected) / sizeof(expected[0]); k++) {
    int got = fgetc(fp);
    assert(got == expected[k]);
    (void) got;
  }
  // NOTE(review): WriteClusterType ends this list with a 0 byte that is not
  // consumed here; presumably the caller handles it -- confirm.

  // Construct before freeing tmpl: the constructor duplicates *tmpl, but it
  // also stores &*tmpl in the pattern set without ownership. Empty() drops
  // that soon-to-dangle pointer and resets the count.
  Cluster *cls = new Cluster(*tmpl);
  cls->Empty();
  delete tmpl;
  return cls;
}

// Modifies self; returns *this.
// Reads one record in the layout WriteCluster produces:
//   1. the population count;
//   2. the mean, then the variance;
//   3. the set size, then that many patterns.
// The existing mean and variance objects are reused, so they must already
// have the file's pattern type. If either count field cannot be read, it is
// treated as 0 rather than left indeterminate; an indeterminate set size
// could otherwise drive the read loop an arbitrary number of times.
Cluster &Cluster::ReadCluster(FILE *fp) {
  unsigned long clscount = 0;
  unsigned long setcount = 0;

  // Read in the population count
  if (fread(&clscount, sizeof(unsigned long), 1, fp) != 1)
    clscount = 0;

  // Read in mean and variance -- already has appropriate pattern type
  mean.ReadPattern(fp);
  variance.ReadPattern(fp);

  // Read in set.
  // NOTE(review): patterns currently flagged for deletion are dropped here
  // without being freed (as Empty() does); presumably they are owned
  // elsewhere -- confirm.
  pset.RemoveAll();
  pdel.RemoveAll();
  if (fread(&setcount, sizeof(unsigned long), 1, fp) != 1)
    setcount = 0;
  // NOTE(review): the delete flag of each new pattern depends on whether
  // ReadPattern returns a reference or a pointer (declared outside this
  // file) -- confirm it does not leak.
  for (unsigned long m = 0; m < setcount; m++)
    AddPattern(mean.DuplicateType()->ReadPattern(fp));

  count = clscount;  // replace the number accumulated by AddPattern

  return *this;
}

// Reads the cluster type line from fp. For a recognised type, delegates to
// ReadClusterType, which takes ownership of tmpl. Returns NULL on a read
// error or an unrecognised type. In that case tmpl is not deleted here.
// NOTE(review): confirm callers free tmpl when NULL is returned.
Cluster *Cluster::SelectClusterType(FILE *fp, Pattern *tmpl) {
  char filetype[128];

  if (fgets(filetype, sizeof(filetype), fp) == NULL)
    return NULL;

  // Strip only an actual trailing line terminator. Handles an empty line
  // and a line without a newline, and tolerates "\r\n".
  size_t len = strlen(filetype);
  while (len > 0 && (filetype[len - 1] == '\n' || filetype[len - 1] == '\r'))
    filetype[--len] = '\0';

  // GENERAL_CLUSTER_TYPE is the name WriteClusterType emits. GAUS_PATTERN_TYPE
  // (defined elsewhere) is still accepted, as before.
  if (!strcmp(GAUS_PATTERN_TYPE, filetype)
      || !strcmp(GENERAL_CLUSTER_TYPE, filetype))
    return ReadClusterType(fp, tmpl);

  return NULL;
}
