//
// MVC_1RSB_Single.cpp
//
// Studying the minimal vertex cover problem by first-step
// replica-symmetry-breaking mean-field theory.
//
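// Input: a graph file holding the vertex count N and the edge count M,
// followed by M edges "v1 v2" with vertex indices in 1..N (for example an
// instance of the random graph with mean degree c).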
//
//For more information about the underlying mean-field theory, see Chapter 5
//of the book:
//Hai-Jun Zhou, "Spin Glass and Message Passing"
//(Science Press, Beijing, 2015).
//
// HISTORY:
// 27-30.09.2012
//
// PROGRAMMER:
// Haijun Zhou
// Institute of Theoretical Physics, Chinese Academy of Sciences,
// Zhong-Guan-Cun East Road 55, Beijing 100190, China
// http://www.itp.ac.cn/~zhouhj
// email: zhouhj@itp.ac.cn
//


#include <cmath>
#include <exception>
#include <fstream>
#include <iostream>
#include <string>
#include <valarray>
#include "zhjrandom.h"

using namespace std;


//coarse-grained survey-propagation message sent along one directed edge
struct mstruct
{
  double weight_warning; //probability of warning messages
  //The default argument is part of the declaration so that readers (and
  //valarray<mstruct>, which default-constructs its elements) can see that
  //an mstruct may be created without arguments, with weight_warning == 0.
  mstruct(double ww=0) : weight_warning(ww) {}
};


//Slot holding the address of one incoming weight-warning message.  Graph()
//points each slot at an output message owned by a neighbouring vertex.
struct mptrstruct
{
  struct mstruct *mptr; //address of the incoming message; null until wired
  mptrstruct(void) : mptr(0) {}
};


//Variable-node record.  Graph() assigns every vertex a contiguous run of
//`degree` output messages (starting at omptr) and `degree` input-address
//slots (starting at imptrptr).
struct vstruct
{
  int degree;                  //number of nearest neighbors
  struct mstruct *omptr;       //first output message of this vertex
  struct mptrstruct *imptrptr; //first input-address slot of this vertex
  vstruct(void) : degree(0), omptr(0), imptrptr(0) {}
};



//Survey-propagation solver for the minimal vertex cover problem.
//Typical call sequence (as in main()): Graph(), InitialMessage0() or
//InitialMessageR(), then for each Y: SetY(), SetDamping(),
//SurveyPropagation(), Statistics().
class MVCsurveypropagation
{

public:
  //rd: random-number generator used for random initial messages and for the
  //random sequential update order; stored as a raw pointer, never deleted
  MVCsurveypropagation(ZHJRANDOMv3* rd);

  void SetY(double yval);       //reset Y

  void SetDamping(double dval); //reset DampingFactor

  //read the first enumber edges of the graph file graphname;
  //false on any error
  bool Graph(string& graphname, int enumber);

  void InitialMessage0(void);   //all messages initialized to be zero

  void InitialMessageR(void);   //messages initialized by rdflt() draws

  //iterate until the largest message change is <= error or count sweeps
  //are done; true on convergence
  bool SurveyPropagation(double error, int count);

  //write per-vertex results to output1name and averages to output2name
  void Statistics(string& output1name, string& output2name);

private:
  //recompute the out-messages of one vertex; maxdiff gets the largest change
  void UpdateMessage(struct vstruct* vptr, double& maxdiff);

  double Y;             //inverse temperature at level of macrostates
  double Expmy;         //exp(-Y)
  double OneMexpmy;     //1-exp(-Y)
  double DampingFactor; //fraction of the message change applied per update

  ZHJRANDOMv3 *rdptr;   //random-number generator (not owned)

  valarray<struct vstruct> Vertex;              //vertices 1..VertexNumber; slot 0 unused
  valarray<struct mstruct> OutMessage;          //output messages, grouped per vertex
  valarray<struct mptrstruct> InMessageAddress; //input addresses, grouped per vertex

  int VertexNumber;
  int EdgeNumber;
  int MaxDegree;                   //maximal vertex degree in the graph

  valarray<int> Permutation;       //array used in random sequential updating
  valarray<double> Product;        //scratch array of length MaxDegree
};


//constructor: stores the random-number generator and puts every scalar
//member into a defined state.  Y starts at 0, which makes the cached factors
//exp(-Y)=1 and 1-exp(-Y)=0.  The damping factor starts at 1 (no damping).
//Callers normally override both through SetY() and SetDamping() before
//calling SurveyPropagation() or Statistics().
MVCsurveypropagation::MVCsurveypropagation(ZHJRANDOMv3 *rd)
  : Y(0),
    Expmy(1.0e0),
    OneMexpmy(0),
    DampingFactor(1.0e0),
    rdptr(rd),
    VertexNumber(0),
    EdgeNumber(0),
    MaxDegree(0)
{
}


//set inverse temperature at the level of macrostates
void MVCsurveypropagation::SetY(double yval)
{
  Y = yval;
  Expmy=exp(-Y);
  return ;
}


//set damping factor of survey propagation updating
void MVCsurveypropagation::SetDamping(double dval)
{
  DampingFactor=dval; //belong to (0,1] interval. 0: no updating; 1: no damping
  return ;
}


//true if (v1,v2) is an admissible edge of a graph with vertices 1..vnumber:
//both endpoints in range and not a self-loop
static bool GraphEdgeIsValid(int v1, int v2, int vnumber)
{
  return v1!=v2 && v1>=1 && v1<=vnumber && v2>=1 && v2<=vnumber;
}

//read edge set.
//
//File format: two integers N (vertex count) and M (edge count), followed by
//M pairs "v1 v2" of vertex indices in 1..N.  Only the first `enumber` edges
//are used.  The file is read twice:
//  1. The first pass validates the edges and counts vertex degrees.
//  2. The second pass points each vertex's input-address slots at its
//     neighbours' output messages.
//
//Returns false, after printing a diagnostic to cerr, if the file cannot be
//opened, the header or an edge is malformed, an edge is a self-loop or has an
//out-of-range endpoint, enumber is negative or exceeds M, or memory
//allocation fails.
bool MVCsurveypropagation::Graph
(
  string& graphname,   //file name
 int enumber   //number of edges to be read
 )
{
  if(enumber<0)
    {
      cerr<<"Number of edges to be read must be non-negative.\n";
      return false;
    }
  ifstream graph(graphname.c_str());
  if(!graph.good())
    {
      cerr<<"Graph probably non-existant.\n";
      return false;
    }
  graph
    >>VertexNumber
    >>EdgeNumber
    ;
  if(!graph || VertexNumber<1 || EdgeNumber<0)
    {
      cerr<<"Graph header incorrect.\n";
      return false;
    }
  if(EdgeNumber<enumber)
    {
      cerr<<"Not so many edges in the graph.\n";
      return false;
    }
  //only the first enumber edges will be read
  EdgeNumber=enumber;

  //first pass: validate edges and count degrees (Vertex[0] is unused;
  //resize() also resets every degree to 0)
  try { Vertex.resize(VertexNumber+1); }
  catch(const bad_alloc&)
    { cerr<<"Vertex not initialized.\n"; return false; }
  for(int eindex=0; eindex<EdgeNumber; ++eindex)
    {
      int v1=0, v2=0;
      graph
	>>v1
	>>v2
	;
      //rejects truncated/non-numeric input, self-loops and endpoints outside
      //1..VertexNumber (negative ones included)
      if(!graph || !GraphEdgeIsValid(v1,v2,VertexNumber))
	{
	  cerr<<"Graph incorrect at line "<<eindex+1<<endl;
	  return false;
	}
      ++(Vertex[v1].degree);
      ++(Vertex[v2].degree);
    }
  graph.close();

  try { OutMessage.resize(2*EdgeNumber); }
  catch(const bad_alloc&)
    { cerr<<"OutMessage not initialized.\n"; return false; }
  try { InMessageAddress.resize(2*EdgeNumber); }
  catch(const bad_alloc&)
    { cerr<<"InMessageAddress not initialized.\n"; return false; }

  //give every vertex a contiguous run of `degree` output messages and input
  //slots.  degree is then reset so the second pass can reuse it as a fill
  //counter (it ends up equal to the first-pass value again).
  int position=0;
  MaxDegree=0;
  for(int v=1; v<=VertexNumber; ++v)
    {
      const int degree=Vertex[v].degree;
      if(degree>0)
	{
	  Vertex[v].omptr = &OutMessage[position];
	  Vertex[v].imptrptr = &InMessageAddress[position];
	}
      else
	{
	  //isolated vertex: its pointers are never dereferenced, and
	  //&OutMessage[position] may be out of range (position == size)
	  Vertex[v].omptr = 0;
	  Vertex[v].imptrptr = 0;
	}
      position += degree;
      if(degree>MaxDegree) MaxDegree=degree;
      Vertex[v].degree=0;
    }

  //second pass on a fresh stream (no stale eof/fail flags from the first
  //pass): for edge (v1,v2), v1's next input slot points at v2's next
  //output message, and vice versa
  ifstream graph2(graphname.c_str());
  int headerVertexNumber=0, headerEdgeNumber=0; //already validated above
  graph2
    >>headerVertexNumber
    >>headerEdgeNumber
    ;
  for(int eindex=0; eindex<EdgeNumber; ++eindex)
    {
      int v1=0, v2=0;
      graph2
	>>v1
	>>v2
	;
      if(!graph2 || !GraphEdgeIsValid(v1,v2,VertexNumber))
	{
	  cerr<<"Graph incorrect at line "<<eindex+1<<endl;
	  return false;
	}
      struct mptrstruct *imptrptr=Vertex[v1].imptrptr + Vertex[v1].degree;
      imptrptr->mptr = Vertex[v2].omptr + Vertex[v2].degree;
      imptrptr=Vertex[v2].imptrptr + Vertex[v2].degree;
      imptrptr->mptr = Vertex[v1].omptr + Vertex[v1].degree;
      ++(Vertex[v1].degree);
      ++(Vertex[v2].degree);
    }
  graph2.close();

  cout
    <<VertexNumber
    <<'\t'
    <<EdgeNumber
    <<'\t'
    <<(2.0e0*EdgeNumber)/(1.0e0*VertexNumber)  //mean vertex degree
    <<endl
    ;
  try { Permutation.resize(VertexNumber); }
  catch(const bad_alloc&)
    { cerr<<"Permutation not initialized.\n"; return false; }
  for(int v=0; v<VertexNumber; ++v) Permutation[v]=v+1;
  try { Product.resize(MaxDegree); }
  catch(const bad_alloc&)
    { cerr<<"Product not initialized.\n"; return false; }

  return true ;
}


//all messages initialized to be zero.
//Graph() stores the output messages of all vertices contiguously in
//OutMessage, so clearing every element resets every message.  Indexing
//(instead of taking &OutMessage[0]) keeps this well defined for a graph
//without edges, where OutMessage is empty.
void MVCsurveypropagation::InitialMessage0(void)
{
  for(size_t i=0; i<OutMessage.size(); ++i)
    OutMessage[i].weight_warning = 0;
  return ;
}


//messages initialized to be uniformly distributed.
//Every output message receives its own rdflt() draw, taken in storage order
//(the same order as walking the vertices 1..N).  rdflt() is presumably
//uniform on [0,1); its exact range is defined in zhjrandom.h (TODO confirm).
//Indexing (instead of taking &OutMessage[0]) keeps this well defined for a
//graph without edges, where OutMessage is empty.
void MVCsurveypropagation::InitialMessageR(void)
{
  for(size_t i=0; i<OutMessage.size(); ++i)
    OutMessage[i].weight_warning = rdptr->rdflt();
  return ;
}


//update the output weight warnings from a vertex.
//
//For neighbour d the new message is
//    P_d / (exp(-Y) + (1-exp(-Y)) * P_d),   P_d = prod_{k!=d} (1 - w_k),
//where the w_k are the incoming weight warnings of the vertex.  The cavity
//products P_d are built in O(degree) from prefix and suffix products,
//without division, instead of an O(degree^2) double loop.  All incoming
//messages are read before any output message is written.  Each output
//moves a fraction DampingFactor of the way to its new value.
//
//maxdiff receives the largest |new - old| before damping (0 for an
//isolated vertex).
void MVCsurveypropagation::UpdateMessage
(
 struct vstruct *vptr,
 double& maxdiff //message difference
 )
{
  maxdiff=0;

  const int degree=vptr->degree;
  if(degree==0) return ; //no updating

  struct mptrstruct *imptrptr = vptr->imptrptr;

  //forward pass: Product[d] = prod_{k<d} (1 - w_k)
  double prefix=1.0e0;
  for(int d=0; d<degree; ++d)
    {
      Product[d]=prefix;
      prefix *= 1.0e0 - imptrptr[d].mptr->weight_warning;
    }
  //backward pass: multiply in prod_{k>d} (1 - w_k)
  double suffix=1.0e0;
  for(int d=degree-1; d>=0; --d)
    {
      Product[d] *= suffix;
      suffix *= 1.0e0 - imptrptr[d].mptr->weight_warning;
    }

  struct mstruct *omptr=vptr->omptr;
  for(int d=0; d<degree; ++d)
    {
      const double diff = Product[d]/(Expmy +OneMexpmy*Product[d])
	- omptr[d].weight_warning;
      if(fabs(diff)>maxdiff) maxdiff=fabs(diff);
      omptr[d].weight_warning += DampingFactor*diff;
    }

  return ;
}


//Try to find a fixed point of the survey propagation.
//
//Each sweep updates every vertex once, in a fresh random order.  The order
//comes from an in-place Fisher-Yates shuffle of Permutation driven by
//rdflt().  Sweeps continue until the largest undamped message change of a
//sweep is <= error, or until count sweeps are done.  At least one sweep is
//always performed.  The MaxDiff of every 10th sweep is echoed to cerr as a
//progress trace.
//Returns true if the threshold was reached.
bool MVCsurveypropagation::SurveyPropagation
(
 double error, //error threshold
 int count //iteration steps
 )
{
  int iter=0;
  double MaxDiff;
  do {
    MaxDiff=0;
    for(int quant=VertexNumber; quant>0; --quant)
      {
	int iii = static_cast<int>(quant * rdptr->rdflt());
	//keep the draw inside [0,quant) even if rdflt() could return 1.0
	//(its range is defined in zhjrandom.h, not visible here)
	if(iii>=quant) iii=quant-1;
	else if(iii<0) iii=0;
	const int v = Permutation[iii];
	Permutation[iii] = Permutation[quant-1];
	Permutation[quant-1] = v;
	double Diff;
	UpdateMessage(&Vertex[v], Diff);
	if(Diff>MaxDiff) MaxDiff=Diff;
      }
    if((iter % 10)==0)
      {
	cerr
	  << MaxDiff
	  <<'\t'
	  ;
	cerr.flush();
      }
    ++iter;
  } while (MaxDiff>error && iter<count);

  const bool converged = (MaxDiff<=error);
  cerr
    <<'\t'
    <<MaxDiff
    <<(converged ? "  :-)\n" : "  :-(\n")
    ;
  return converged;
}


//Calculate mean energy, complexity, r0, rstar, r1, etc. from the current
//messages at the current Y.
//
//output1name: per-vertex file, one line "v  r0  r*  r1" per vertex.  r0 is
//  the vertex's frozen-0 fraction, r* its unfrozen fraction, and
//  r1 = 1 - r0 - r* (presumably the frozen-1 fraction).
//output2name: averages.  The first line is "N  M  Y"; the second is
//  "yG  MVC  complexity  r0  r*  r1", all per vertex, with
//  complexity = Y*MVC - yG.
//A file that cannot be opened is reported on cerr; the other output is
//still produced.  Nothing is computed for a graph without vertices.
void MVCsurveypropagation::Statistics
(
  string& output1name, //single vertex
  string& output2name //average
 )
{
  if(VertexNumber<1)
    {
      cerr<<"No vertices: statistics not computed.\n";
      return ;
    }

  double
    yGvertex=0, //vertex contribution to yG
    yGedge=0, //edge contribution to yG
    Evertex=0, //vertex contribution to E
    Eedge=0, //edge contribution to E
    r0=0, //average value of frozen-0 fraction
    rstar=0; //average value of unfrozen fraction

  ofstream single(output1name.c_str());
  if(!single)
    cerr<<"Cannot open "<<output1name<<" for writing.\n";
  for(int v=1; v<=VertexNumber; ++v)
    {
      struct vstruct *vptr=&Vertex[v];
      const int degree=vptr->degree;
      if(degree==0)
	{
	  //isolated vertex: counted entirely as frozen-0; it contributes
	  //nothing to yG or E
	  r0 += 1;
	  single
	    <<v
	    <<'\t'
	    <<1      //r0
	    <<'\t'
	    <<0      //r*
	    <<'\t'
	    <<0      //r1
	    <<endl;
	  continue;
	}
      struct mptrstruct *imptrptr=vptr->imptrptr;
      struct mstruct *omptr=vptr->omptr;
      //weight accumulates prod_j (1 - w_j), and Product[k] accumulates
      //w_k * prod_{j!=k} (1 - w_j), with w_j the incoming weight warnings
      for(int d=0; d<degree; ++d) Product[d]=1.0e0;
      double weight=1.0e0;
      double value;
      for(int d=0; d<degree; ++d)
	{
	  const double weight_warning=imptrptr->mptr->weight_warning;
	  value=1.0e0-weight_warning;
	  weight *= value;
	  for(int k=0; k<degree; ++k)
	    {
	      if(k!=d) Product[k] *= value;
	      else Product[k] *= weight_warning;
	    }

	  //edge term.  Every edge is visited from both endpoints, hence the
	  //factor 0.5 applied to yGedge and Eedge below.
	  value=weight_warning*omptr->weight_warning;
	  const double normale=1.0e0-OneMexpmy*value;
	  yGedge += -log(normale);
	  Eedge +=(Expmy*value)/normale;

	  ++imptrptr;
	  ++omptr;
	}
      const double normalv=Expmy+OneMexpmy*weight;
      yGvertex += -log(normalv);
      Evertex += (Expmy*(1.0e0-weight))/normalv;
      const double r0val=weight/normalv;
      r0 += r0val;
      value=0;
      for(int d=0; d<degree; ++d)
	value += Product[d];
      const double rstarval=(Expmy*value)/normalv;
      rstar += rstarval;
      single
	<<v
	<<'\t'
	<<r0val
	<<'\t'
	<<rstarval
	<<'\t'
	<<1.0e0-r0val-rstarval
	<<endl;
    }
  single.close();

  r0 /= VertexNumber;
  rstar /=VertexNumber;
  double yG=(yGvertex-0.5e0*yGedge)/VertexNumber;
  double MVC=(Evertex-0.5e0*Eedge)/VertexNumber;
  double complexity=Y*MVC-yG;

  ofstream average(output2name.c_str() );
  if(!average)
    cerr<<"Cannot open "<<output2name<<" for writing.\n";
  average
    <<VertexNumber
    << '\t'
    <<EdgeNumber
    << '\t'
    <<Y
    <<endl
    <<yG
    <<'\t'
    <<MVC
    <<'\t'
    <<complexity
    <<'\t'
    <<r0
    <<'\t'
    <<rstar
    <<'\t'
    <<1.0e0-r0-rstar
    <<endl<<endl
    ;
  average.close();

  return ;
}


//Print `prompt`, flush it, and read one value of type T from cin.
//Returns false if the read failed (end of input or malformed value).
template <class T>
static bool PromptAndRead(const char* prompt, T& value)
{
  cout<<prompt;
  cout.flush();
  cin>>value;
  return !cin.fail();
}

//Interactive driver.  It reads a random seed, the graph file and the number
//of edges, then the message initialization.  For each Y >= 0 it runs SP
//(re-prompting for SP parameters until it converges) and writes the
//statistics files.
//Exit status: 0 after a negative Y is entered, -1 on invalid or exhausted
//input or on a graph-reading error.
int main(int argc, char ** argv)
{
  cout
    <<"mvcSP for single simple graph.\n"
    <<"Haijun Zhou (last modified 30.09.2012)\n\n"
    ;
  int rdseed  =23456789;
  if(!PromptAndRead("rdseed= ", rdseed))
    return -1;
  ZHJRANDOMv3 rdgenerator(rdseed);
  //discard the first rdseed numbers of the generator
  for(int i=0; i<rdseed; ++i)
    rdgenerator.rdflt();
  MVCsurveypropagation mvc(&rdgenerator);

  string gfile;
  if(!PromptAndRead("problem instance ", gfile))
    return -1;
  int edgenumber;
  if(!PromptAndRead("Edgenumber =", edgenumber) || edgenumber<0)
    return -1;

  if(mvc.Graph(gfile, edgenumber) == false)
    return -1;

  char itag;
  if(!PromptAndRead("Message initialization (0/1)? ", itag))
    return -1;
  if(itag=='0')
    mvc.InitialMessage0();
  else
    mvc.InitialMessageR();

  while(true)
    {
      //a failed read must end the loop: otherwise end of input would leave
      //yval >= 0 and the program would spin forever
      double yval;
      if(!PromptAndRead("Y (negative to quit) ", yval))
	return -1;
      if(yval<0.0e0)
	break;

      mvc.SetY(yval);

      bool converged=false;
      while(!converged)
	{
	  double error;
	  int iterations;
	  double DampingFactor;
	  if(!PromptAndRead("SP error ", error)
	     || !PromptAndRead("SP iterations ", iterations)
	     || !PromptAndRead("DampingFactor (0,1)?", DampingFactor))
	    return -1;
	  if(DampingFactor>1.0e0) DampingFactor=1.0e0;
	  else if(DampingFactor<0.01e0) DampingFactor=0.01e0;
	  mvc.SetDamping(DampingFactor);
	  converged=mvc.SurveyPropagation(error,iterations);
	}
      string filesingle, fileaverage;
      if(!PromptAndRead("Single file name: ", filesingle)
	 || !PromptAndRead("Average file name: ", fileaverage))
	return -1;
      mvc.Statistics(filesingle,fileaverage);
    }

  return 0;
}
