//
// BeliefPropReinforceSAT.cpp
//
// This program complements Section 6.3 of the book
// "Spin Glass and Message Passing" by Hai-Jun Zhou
// (Science Press, Beijing, 2015).
// Program last modified on 21.08.2012
//
// DESCRIPTION:
// Entropic belief-propagation-guided-reinforcement for a general SAT formula.
// The program tries to output a partial solution based on the field on each
// variable.
//
// (input) a cnf sat formula.
// (output) a (partial) solution to the formula.
//
// compile this program using
// c++ -O3 -o SATbpr.exe BeliefPropReinforceSAT.cpp
//
// Programmer:
// Hai-Jun Zhou
// Institute of Theoretical Physics, Chinese Academy of Sciences,
// Beijing 100190, China
// Webpage http://power.itp.ac.cn/~zhouhj
// Email zhouhj@itp.ac.cn
//
//

#include <cmath>
#include <exception>
#include <fstream>
#include <iostream>
#include <set>
#include <string>
#include <valarray>
#include "zhjrandom.h"

using namespace std;


//Fields acting on one variable, accumulated over the clause->variable messages
//it receives (soft fields are log-messages, hard fields are warnings).
struct etastruct
{
  double  p; //sum of soft fields to a variable from positive edges
  double  m; //sum of soft fields to a variable from negative edges
  int punsat; //sum of hard fields to a variable from positive edges
  int munsat; //sum of hard fields to a variable from negative edges
  //Default arguments are given here, on the in-class declaration: adding them
  //only on the out-of-class definition would turn this constructor into a
  //default constructor after the fact, which is ill-formed (CWG 1344).
  etastruct(double av=0, double bv=0, int cv=0, int dv=0);
};
etastruct::etastruct(double av, double bv, int cv, int dv)
  : p(av), m(bv), punsat(cv), munsat(dv)
{
}


//One clause->variable edge, stored inside the owning clause's slice of the
//global edge array.
struct actionctovstruct
{
  int var; //index of the variable at the other end of this edge
  bool bar; //true: spin=+1 satisfies the clause; false: spin=-1 satisfies it
  double u; //clause->variable message (u<=0 soft log-message, u=1 warning)
  actionctovstruct(void);
};
actionctovstruct::actionctovstruct(void)
  : var(0),
    bar(true),
    u(0)
{
}


//A clause: a contiguous run of clause->variable edges.
struct clausestruct
{
  struct actionctovstruct *actionctov; //first edge of this clause in the edge array
  int degree; //number of edges (distinct variables) of this clause
  clausestruct(void);
};
clausestruct::clausestruct(void)
  : actionctov(0),
    degree(0)
{
}


//One variable->clause edge: identifies a clause of the variable and where the
//variable sits inside that clause's edge list.
struct actionvtocstruct
{
  struct clausestruct *clause; //clause at the other end of this edge
  int position; //index of this variable within clause->actionctov
  actionvtocstruct(void);
};
actionvtocstruct::actionvtocstruct(void)
  : clause(0),
    position(0)
{
}


//A variable node: its edges to clauses, its current spin, and its fields.
struct variablestruct
{
  int degree; //number of clause edges of this variable
  int spin; //current spin value (0 = not yet assigned)
  double reinforce_field; //accumulated reinforcement field
  struct actionvtocstruct *actionvtoc; //first edge of this variable in the edge array
  struct etastruct *etaptr; //hard and soft fields of this variable
  variablestruct(void);
};
variablestruct::variablestruct(void)
  : degree(0),
    spin(0),
    reinforce_field(0),
    actionvtoc(0),
    etaptr(0)
{
}


//Scratch record used by Update_u to compute clause->variable messages
//precisely and without overflow. For each position of a clause it keeps a
//running minimum (minv), an accumulated exponent (expv) and a residual
//factor (resv). In Update_u, minv also acts as a state flag:
//0 = not yet initialized, 1 = clause satisfied by a frozen variable,
//>=2 = regular soft message.
struct uarraystruct
{
  double minv;
  double expv;
  double resv;
  //Default arguments are given here, on the in-class declaration: adding them
  //only on the out-of-class definition would turn this constructor into a
  //default constructor after the fact, which is ill-formed (CWG 1344).
  uarraystruct(double mv=0, double ev=0, double rv=0);
};
uarraystruct::uarraystruct(double mv, double ev, double rv)
  : minv(mv), expv(ev), resv(rv)
{
}


//Belief-propagation-guided reinforcement solver for a CNF SAT formula.
//Usage: Readformula() loads the formula, bpReinforce() runs the search, and
//Report_partialsolution() writes the current spin configuration to a file.
class BPRforSAT
{
public:
  //the random number generator is borrowed (not deleted here) and must
  //outlive this object
  explicit BPRforSAT(ZHJRANDOMv3* );
  //(formula file, number of leading clauses to read); returns false on error
  bool Readformula( string&, const int);
  //(sweep budget without improvement, trajectory log file, solution file).
  //Must match the out-of-class definition, which takes the log file name
  //by const reference.
  bool bpReinforce(const int, const string&, string&);
  //write the current spin configuration to the given file
  void Report_partialsolution( string&);
  
private:
  int Clause_number; //number of clauses
  int Clause_number_nontrivial; //number of nontrivial clauses
  int Iterations; //number of iterations for BP convergence
  int Max_clause_degree; //maximal degree of the clauses
  int Number_edges; //number of clause-variable interactions
  int Number_freespin; //number of non-isolated variables
  int Number_isolated_variables; //number of isolated variables
  int Number_unsat_clauses; //number of violated clauses
  int Variable_number; //number of variables

  double Iteration_steps_total; //total number of iterations
  double Reinforce_delta; //magnitude of reinforcement
  //  double Reinforce_field_max; //cut-off of maximal external field
  double Reinforce_lambda; //exponent in no_Reinforce_probability = iter^(-lambda)
  double no_Reinforce_probability; //probability of not performing reinforcement
  ZHJRANDOMv3 *rdptr; //random number generator (not owned)

  set<int> Isolated_variables; //set of isolated variables
  set<int> Free_variables; //set of non-isolated variables

  valarray<actionctovstruct> Alledgesctov; //clause -> variable array
  valarray<actionvtocstruct> Alledgesvtoc; //variable -> clause array
  valarray<clausestruct> Clauses; //clause array (nontrivial clauses at 1..Clause_number_nontrivial)
  valarray<etastruct> Etaset; //fields array
  valarray<int>  Permutation; //belief-propagation permutation array
  valarray<uarraystruct> Uarray; //array used in BP iteration
  valarray<variablestruct> Variables; //variable array (indices 1..Variable_number)

  void Update_u(struct clausestruct*);
  void Iterate(void);
  void Compute_eta(void);
  void Compute_eta(struct variablestruct* );
};


//Program entry point.
//Reads a control file (argv[1] if given, otherwise the default name below)
//holding, whitespace-separated and in this order: the random seed, the CNF
//formula file name, the number of clauses to read, the solution output file
//name and the trajectory log file name. Then runs the BP-reinforcement search.
//Exit status: 0 if a solution was found, 1 if the search failed, -1 if the
//control file or the formula could not be read.
int main(int argc, char ** argv)
{
  const char *control_file=(argc>1 ? argv[1] : "4sata9p46n20k.cnf02.input");
  ifstream inputfile(control_file);
  unsigned int rdseed=0;
  if(!(inputfile>>rdseed))
    {
      cerr<<"Cannot read the random seed from control file "<<control_file<<".\n";
      return -1;
    }
  ZHJRANDOMv3 rdgenerator(rdseed);
  //discard the first rdseed draws of the generator
  for(unsigned int i=0; i<rdseed; ++i) rdgenerator.rdflt();
  BPRforSAT sp_zhj(&rdgenerator);

  string formula;
  int clnumber=0;
  if(!(inputfile>>formula>>clnumber))
    {
      cerr<<"Cannot read the formula file name and clause count from "<<control_file<<".\n";
      return -1;
    }
  if(sp_zhj.Readformula(formula, clnumber) == false)
    return -1;

  string solution;
  string trajectory;
  if(!(inputfile>>solution>>trajectory))
    {
      cerr<<"Cannot read the solution and trajectory file names from "<<control_file<<".\n";
      return -1;
    }
  inputfile.close();

  const int iterations_long=20000;
  return sp_zhj.bpReinforce(iterations_long, trajectory, solution) ? 0 : 1;
}


//Constructor for BPRforSAT class.
//Every scalar member is given a defined value, so that the object is safe to
//query (e.g. Report_partialsolution writes an empty configuration) even
//before Readformula has run. The generator pointer is stored, not owned.
BPRforSAT::BPRforSAT(ZHJRANDOMv3 *rdp)
  : Clause_number(0),
    Clause_number_nontrivial(0),
    Iterations(50),
    Max_clause_degree(0),
    Number_edges(0),
    Number_freespin(0),
    Number_isolated_variables(0),
    Number_unsat_clauses(0),
    Variable_number(0),
    Iteration_steps_total(0),
    Reinforce_delta(0.005),
    Reinforce_lambda(0.01),
    no_Reinforce_probability(1.0),
    rdptr(rdp)
{
  //  Reinforce_field_max=1.23875;
}


//Read formula in CNF form
bool BPRforSAT::Readformula(
			     string& filename, //the formula file
			    int clnumber //reading the first clnumber clauses
			    )
{
  ifstream source;
  source.open(filename.c_str());
  if( !source.good() )
    {
      cerr<<"Error in cnf formula input (file possibly non-existent).\n";
      return false;
    }
  string abcd1, abcd2;
  source
    >>abcd1
    >> abcd2
    >> Variable_number
    >> Clause_number;
  if(Clause_number<clnumber)
    {
      cerr<<"Not so many clauses!\n";
      source.close();
      return false;
    }
  Clause_number=clnumber;
  try { Variables.resize(Variable_number+1); } catch(bad_alloc) { return false; }
  try { Clauses.resize(Clause_number+1); } catch(bad_alloc) { return false; }
  try { Etaset.resize(Variable_number); } catch(bad_alloc) { return false; }
  //first pass for counting
  int edge_index=0;
  Max_clause_degree=0;
  Clause_number_nontrivial=0;
  set<int> neighbors;
  bool satisfied;
  int var;
  for(int m=1; m<=Clause_number; ++m)
    {
      neighbors.clear();
      satisfied=false;
      source
	>>var;
      while(var!=0)
	{
	  if(abs(var)>Variable_number)
	    {
	      cerr
		<<"Too many variables.\n";
	      source.close();
	      return false;
	    }
	  if(neighbors.find(var)==neighbors.end())
	    {
	      neighbors.insert(var);
	      if(neighbors.find(-var)!=neighbors.end())
		satisfied=true;
	    }
	  source
	    >>var;
	}
      if(!satisfied)
	{
	  ++Clause_number_nontrivial;
	  for(set<int>::const_iterator sci=neighbors.begin(); sci!=neighbors.end(); ++sci)
	    {
	      var=*sci;
	      if(var<0) var *=-1;
	      ++(Variables[var].degree);
	      ++edge_index;
	    }
	  Clauses[Clause_number_nontrivial].degree=neighbors.size();
	  if(Max_clause_degree<neighbors.size()) Max_clause_degree=neighbors.size();
	}
    }
  source.close();
  Number_edges=edge_index;
  try { Alledgesctov.resize(Number_edges); } catch(bad_alloc) { return false;}
  try { Alledgesvtoc.resize(Number_edges); } catch(bad_alloc) { return false;}
  struct actionvtocstruct *clause_ptr=&Alledgesvtoc[0];
  struct etastruct *eta_ptr=&Etaset[0];
  Number_isolated_variables=0;
  Free_variables.clear();
  Isolated_variables.clear();
  for(int var=1; var<=Variable_number; ++var)
    {
      Variables[var].spin=1;
      if(Variables[var].degree)
	{
	  Free_variables.insert(var);
	  Variables[var].actionvtoc=clause_ptr;
	  clause_ptr += Variables[var].degree;
	  Variables[var].etaptr=eta_ptr;
	  ++eta_ptr;
	  Variables[var].degree=0;
	}
      else
	{
	  Isolated_variables.insert(var);
	  ++Number_isolated_variables;
	}
    }
  struct actionctovstruct *variable_ptr=&Alledgesctov[0];
  for(int m=1; m<=Clause_number_nontrivial; ++m)
    {
      Clauses[m].actionctov=variable_ptr;
      variable_ptr += Clauses[m].degree;
    }

  //second pass to do the actual reading
  source.open(filename.c_str());
  source
    >>abcd1
    >> abcd2
    >> Variable_number
    >> Clause_number;
  Clause_number=clnumber;
  Clause_number_nontrivial=0;
  for(int m=1; m<=Clause_number; ++m)
    {
      neighbors.clear();
      satisfied=false;
      source
	>>var;
      while(var!=0)
	{
	  if(neighbors.find(var)==neighbors.end())
	    {
	      neighbors.insert(var);
	      if(neighbors.find(-var)!=neighbors.end())
		satisfied=true;
	    }
	  source
	    >>var;
	}
      if(!satisfied)
	{
	  int size=0;
	  ++Clause_number_nontrivial;
	  for(set<int>::const_iterator sci=neighbors.begin(); sci!=neighbors.end(); ++sci)
	    {
	      Clauses[Clause_number_nontrivial].actionctov[size].bar= ( *sci >0? true : false);
	      var= *sci;
	      if(var<0) var *= -1;
	      Clauses[Clause_number_nontrivial].actionctov[size].var=var;
	      Variables[var].actionvtoc[Variables[var].degree].clause
		= &Clauses[Clause_number_nontrivial];
	      Variables[var].actionvtoc[Variables[var].degree++].position=size;
	      ++size;
	    }
	}
    }
  source.close();
  try { Uarray.resize(Max_clause_degree); } catch(bad_alloc) { return false; }
  Number_freespin=Variable_number-Number_isolated_variables;
  if(Number_isolated_variables>0)
    for(set<int>::const_iterator sci=Isolated_variables.begin(); sci!=Isolated_variables.end(); ++sci)
      Variables[*sci ].spin=(rdptr->rdflt()<=0.5? 1:-1);
  try { Permutation.resize(Clause_number_nontrivial) ; } catch(bad_alloc) { return false; }
  for(int cl=0; cl<Clause_number_nontrivial; ++cl)
    Permutation[cl]=cl+1;
  Number_unsat_clauses=Clause_number_nontrivial;
  return true;
}


//update the message on a variable node and update the reinforcement field
void BPRforSAT::Compute_eta(struct variablestruct *varptr)
{
  varptr->etaptr->p=0;
  varptr->etaptr->m=0;
  varptr->etaptr->punsat=0;
  varptr->etaptr->munsat=0;
  struct actionvtocstruct *actionvtoc=varptr->actionvtoc;
  struct actionctovstruct *edgeptr;
  for(int i=0; i<varptr->degree; ++i,++actionvtoc)
    {
      edgeptr=(actionvtoc->clause->actionctov)+(actionvtoc->position);
      if(edgeptr->u<=0)
	{
	  if(edgeptr->bar) varptr->etaptr->p += edgeptr->u;
	  else varptr->etaptr->m += edgeptr->u;
	}
      else // u=1 indicates a warning
	{
	  if(edgeptr->bar) ++(varptr->etaptr->punsat);
	  else ++(varptr->etaptr->munsat);
	}
    }
  if(varptr->etaptr->punsat == varptr->etaptr->munsat) //in case variable unfrozen
    {
      double diffv=varptr->etaptr->m - varptr->etaptr->p;
      if(rdptr->rdflt()>no_Reinforce_probability)
	{
	  if(diffv>=0)
	    {
	      varptr->reinforce_field += Reinforce_delta;
	      //	      if(varptr->reinforce_field>Reinforce_field_max)
	      //		varptr->reinforce_field=Reinforce_field_max;
	    }
	  else
	    {
	      varptr->reinforce_field -= Reinforce_delta;
	      //	      if(varptr->reinforce_field < -Reinforce_field_max)
	      //		varptr->reinforce_field= -Reinforce_field_max;
	    }
	}
      varptr->spin = ((diffv + varptr->reinforce_field) > 0 ? 1 : -1);
    }
  else if(varptr->etaptr->punsat > varptr->etaptr->munsat)
    varptr->spin=1;
  else
    varptr->spin=-1;
  return ;
}


//Recompute fields and spins of every non-isolated variable, in increasing
//variable index order.
void BPRforSAT::Compute_eta(void)
{
  set<int>::const_iterator it=Free_variables.begin();
  const set<int>::const_iterator last=Free_variables.end();
  while(it!=last)
    {
      Compute_eta(&Variables[*it]);
      ++it;
    }
}


//Update every clause->variable message u_{a->j} of one nontrivial clause a
//and, incrementally, the fields (p, m, punsat, munsat) of the clause's
//variables, so those fields keep matching the sums of the current messages.
//Also increments Number_unsat_clauses when no variable of the clause
//currently has its satisfying spin value.
//Message encoding used throughout: u<=0 is a soft message (a log value);
//u=1 is a warning (hard field forcing the variable to satisfy the clause).
void BPRforSAT::Update_u(struct clausestruct *clause)
{
  /*
    Uarray[i] holds intermediate data used for a precise computation of the clause to variable message
    u_{a->i}=log[1-\prod_{j\in \partial a\backslash i}(1-J_a^j m_{j->a})/2 ].
    m_{j->a}=(Nja-Pja)/(Nja+Pja), with
    log(Nja)=\sum\limits_{b\in \partial j^{-} \backslash a} u_{b->j}
    log(Pja)=\sum\limits_{b\in \partial j^{+} \backslash a} u_{b->j}
    (the reinforcement field of j is added to log(Nja)-log(Pja) as well).
    Uarray[i].minv also serves as a state flag for position i:
      0   : not initialized, i.e. every other variable is frozen against the clause (-> warning)
      1   : some other variable is frozen in the satisfying direction (-> u=0)
      >=2 : regular case, soft message computed from minv, expv and resv
  */
  for(int dd=0; dd<clause->degree; ++dd) Uarray[dd].minv=0;
  struct actionctovstruct *actionctov=clause->actionctov;
  struct etastruct *etaptr;
  bool satisfied=false; //whether the current spins satisfy this clause
  //Pass 1: compute the cavity field eta_ja of each variable j (its field with
  //this clause's own message removed) and fold it into Uarray[bb] for every
  //other position bb of the clause.
  for(int dd=0; dd<clause->degree; ++dd,++actionctov)
    {
      if(!satisfied)
	if(actionctov->bar==(Variables[actionctov->var].spin==1))
	  satisfied=true;
      int I_ja=actionctov->bar ? 1: -1; //+1 if spin=+1 satisfies the clause, else -1
      etaptr=Variables[actionctov->var].etaptr;
      //net number of warnings on j in the direction that satisfies a
      //(includes a's own warning, if it currently sends one)
      int por=I_ja*(etaptr->punsat - etaptr->munsat);
      double eta_ja = etaptr->m - etaptr->p + Variables[actionctov->var].reinforce_field;
      if(actionctov->u<=0) //no warning
	{
	  if(por==0) //variable unfrozen
	    {
	      eta_ja += I_ja * actionctov->u; //remove a's own soft message from the sum
	      if(eta_ja>=0) eta_ja += 2.0; //a trick: eta_ja>=2 means positive eta_ja
	      else eta_ja -= 2.0; //trick: eta_ja<=-2 means negative eta_ja
	    }
	  else //variable frozen
	    eta_ja = por>0? I_ja : -I_ja; //trick: eta_ja=1 (positive frozen), =-1 (negative frozen)
	}
      else // actionctov->u=1 (warning)
	{
	  if(por==1) //a's own warning is the only imbalance: j is unfrozen without a
	    {
	      if(eta_ja>=0) eta_ja += 2.0;
	      else eta_ja -= 2.0;
	    }
	  else eta_ja=por>1 ? I_ja : -I_ja; //j stays frozen even without a's warning
	}
      if(eta_ja<0)
	{
	  eta_ja *= -1;
	  I_ja *= -1;
	} //eta_ja=2*(cavity field) on variable j, it is positive with respect to I_ja
      if(eta_ja>1.5) //variable j not frozen in absence of a, (eta_ja>=2)
	{
	  double val=1.0e0+exp(2.0e0-eta_ja); //m_{j->a}=[1-exp(2-eta_ja)]/[1+exp(2-eta_ja)]
	  for(int bb=0; bb<clause->degree; ++bb)
	    if(bb!=dd)
	      {
		double minv=Uarray[bb].minv;
		if(minv>1.5) //minv>=2
		  {
		    if(I_ja>0) Uarray[bb].expv += eta_ja-2.0e0; 
		    /* Uarray[i].expv= -\sum_{j\neq i} log[(1-I_aj)/2+e^{2-eta_ja} (1+I_aj)/2] */
		    /* \prod_{j\neq i}(1+e^{2-eta_ja})=1+exp(-Uarray[i].minv)*Uarray[i].resv */
		    //keep the smallest eta seen as the reference scale, to avoid overflow in exp
		    if(minv<=eta_ja)
		      Uarray[bb].resv=val*Uarray[bb].resv+exp(minv-eta_ja);
		    else
		      {
			Uarray[bb].resv=1.0+exp(eta_ja-minv)*val*Uarray[bb].resv;
			Uarray[bb].minv=eta_ja;
		      }
		  }
		else if(minv < 0.5) // not yet initialized (minv=0)
		  Uarray[bb]=uarraystruct(eta_ja,(I_ja>0? eta_ja-2.0 : 0),1); //min,expv,resv
		//else  min=1, no need to proceed (because clause has been satisfied)
	      }
	}
      else // eta_ja=1, frozen (the direction is defined as positive)
	{
	  if(I_ja>0)
	    {
	      for(int bb=0; bb<clause->degree; ++bb)
		if(bb!=dd)
		  Uarray[bb].minv=1; //clause satisfied u_{b->i}=0
	    }
	  //else I_ja<0, clause can not satisfied by variable
	}
    }
  //Pass 2: turn Uarray[dd] into the new message for position dd and apply the
  //difference between new and old message to that variable's fields.
  actionctov=clause->actionctov;
  for(int dd=0; dd<clause->degree; ++dd,++actionctov)
    {
      double new_u=0; //=log(1.0e0-\prod_{j\neq i} [(1-I_aj m_{j->a})/2])
      etaptr=Variables[actionctov->var].etaptr;
      if(Uarray[dd].minv>=1.5e0) //minv>=2
	{
	  if(Uarray[dd].expv>0)
	    new_u=log(1.0e0-exp(-Uarray[dd].expv)/(1.0e0+exp(2.0e0-Uarray[dd].minv)*Uarray[dd].resv));
	  else new_u=2.0-Uarray[dd].minv+log(Uarray[dd].resv/
					     (1.0e0+exp(2.0e0-Uarray[dd].minv)*Uarray[dd].resv));
	  if(actionctov->u<=0) //normal message
	    {
	      if(actionctov->bar) etaptr->p += new_u-actionctov->u;
	      else                etaptr->m += new_u-actionctov->u;
	    }
	  else //warning
	    {
	      //old warning becomes a soft message: drop the hard field, add the soft one
	      if(actionctov->bar)
		{
		  --(etaptr->punsat);
		  etaptr->p += new_u;
		}
	      else
		{
		  --(etaptr->munsat);
		  etaptr->m += new_u;
		}
	    }
	}
      else if(Uarray[dd].minv>=0.5) //=1, clause satisfied
	{
	  //new message is 0: remove whatever the old message contributed
	  new_u=0;
	  if(actionctov->u<=0)
	    {
	      if(actionctov->bar)
		etaptr->p -= actionctov->u;
	      else
		etaptr->m -= actionctov->u;
	    }
	  else
	    {
	      if(actionctov->bar) --(etaptr->punsat);
	      else                --(etaptr->munsat);
	    }
	}
      else // Uarray[dd].minv=0, warning
	{
	  //new message is a warning; if the old one was already a warning, fields are unchanged
	  new_u=1.0;
	  if(actionctov->u <=0)
	    {
	      if(actionctov->bar)
		{
		  ++(etaptr->punsat);
		  etaptr->p -= actionctov->u;
		}
	      else
		{
		  ++(etaptr->munsat);
		  etaptr->m -= actionctov->u;
		}
	    }
	}
      actionctov->u = new_u;
    }
  if(!satisfied)
    ++Number_unsat_clauses;
  return ;
}


//One BP sweep: update the messages of every nontrivial clause exactly once,
//in a random order, then recompute all variable fields and spins.
//The order comes from an incremental Fisher-Yates shuffle of Permutation:
//at each step a random entry among the first `quant` is swapped to position
//quant-1 and processed.
void BPRforSAT::Iterate(void)
{
  Number_unsat_clauses=0; //re-accumulated by Update_u below
  for(int quant=Clause_number_nontrivial; quant>0; --quant)
    {
      int ii=static_cast<int>(quant*rdptr->rdflt());
      //keep the index in [0, quant): if rdflt() returns 1.0, or the product
      //rounds up to quant, the unclamped index would run past the live range
      if(ii>=quant) ii=quant-1;
      if(ii<0) ii=0;
      int cl=Permutation[ii];
      if(ii<(quant-1))
	{
	  Permutation[ii]=Permutation[quant-1];
	  Permutation[quant-1]=cl;
	}
      Update_u(&Clauses[cl]);
    }
  Compute_eta();
  return ;
}


//Solution searching by belief propagation with reinforcement.
//Repeats BP sweeps until the current spin configuration violates no
//nontrivial clause, or until about `newiterations` consecutive sweeps fail to
//lower the number of violated clauses. On every improvement, a line
//(total sweeps, violated clauses) is appended to the trajectory file and the
//current configuration is written to the solution file.
//Returns true if a satisfying configuration was found.
bool BPRforSAT::bpReinforce(
			    const int newiterations, //number of iterations at each energy level
			    const string& logfilename, //trajectory file
			     string& solution //solution file
			    )
{
  ofstream output(logfilename.c_str(), ios_base::app);
  Iterations=newiterations;
  int old_num_unsat_clauses=Number_unsat_clauses;
  Compute_eta();
  int iter=0;
  do
    {
      Iteration_steps_total += 1;
      //Compute_eta reinforces when a random draw exceeds
      //no_Reinforce_probability = iter^(-lambda). At iter==0 the formula
      //would divide by pow(0,lambda)=0 (undefined behaviour); its limit,
      //never reinforce, is set explicitly instead.
      no_Reinforce_probability = (iter>0 ? 1.0e0/pow(1.0e0*iter,Reinforce_lambda) : 1.0e0);
      Iterate();
      if(Number_unsat_clauses<old_num_unsat_clauses)
	{
	  output<<Iteration_steps_total<<'\t'<<Number_unsat_clauses<<endl;
	  old_num_unsat_clauses=Number_unsat_clauses;
	  Report_partialsolution(solution);
	  iter=0;
	}
      ++iter;
    }
  while (iter<=Iterations && Number_unsat_clauses>0);
  output.close();
  if(Number_unsat_clauses==0)
    {
      cerr
	<<"Solution found. Total iteration steps is "
	<<Iteration_steps_total
	<<endl;
      Report_partialsolution(solution);
      return true;
    }
  else
    {
      cerr
	<<"Fail to find a solution. Total iteration steps is "
	<<Iteration_steps_total
	<<endl;
      return false;
    }
}


//Report a spin configuration.
//Overwrites `outputfile` with a header line giving the number of violated
//clauses, a blank line, and then one line per variable v containing
//v*spin(v). Prints a diagnostic and writes nothing if the file cannot be opened.
void BPRforSAT::Report_partialsolution(
				        string& outputfile
				       )
{
  ofstream output(outputfile.c_str());
  if(!output)
    {
      cerr<<"Cannot open solution file "<<outputfile<<".\n";
      return ;
    }
  output
    <<"Number of UNSAT clauses is "
    <<Number_unsat_clauses
    <<"\n\n";
  //'\n' rather than endl: a single flush at close instead of one per variable
  for(int v=1; v<=Variable_number; ++v)
    output << v*Variables[v].spin << '\n';
  output.close();
  return ;
}
