#include "./include/includes.h"
#include <float.h>
#include <omp.h>

/* memory */
/* Host-side CG work vectors, allocated by malloc_cg64() and released by
 * free_cg64(). NOTE(review): presumably each holds sites_on_node
 * wilson_vectors; the exact allocation is done by the project MEMALIGN macro. */
wilson_vector *cg_p, *cg_mp, *cg_res, *cg_temp;
/* Single-precision counterparts; not referenced in this part of the file. */
float *cg_p_32, *cg_mp_32, *cg_res_32, *cg_temp_32;
/* NOTE(review): not referenced in this part of the file; confirm its use elsewhere. */
int cg_ishift;

/* Allocate the host-side CG work vectors cg_p, cg_mp, cg_res and cg_temp,
 * each requested as sites_on_node elements of wilson_vector.
 * Must be paired with free_cg64().
 * NOTE(review): MEMALIGN is a project macro; its alignment and out-of-memory
 * behavior are defined there, not here. */
__targetHost__ void malloc_cg64()
{
    MEMALIGN(cg_p,wilson_vector,sites_on_node);
    MEMALIGN(cg_mp,wilson_vector,sites_on_node);
    MEMALIGN(cg_res,wilson_vector,sites_on_node);
    MEMALIGN(cg_temp,wilson_vector,sites_on_node);
}
/* Release the host-side CG work vectors allocated by malloc_cg64().
 * NOTE(review): FREE is a project macro; whether it also nulls the
 * pointers is defined there. */
__targetHost__ void free_cg64()
{
    FREE(cg_p,wilson_vector,sites_on_node);
    FREE(cg_mp,wilson_vector,sites_on_node);
    FREE(cg_res,wilson_vector,sites_on_node);
    FREE(cg_temp,wilson_vector,sites_on_node);
}


//target copies of data structures
// All t_* pointers below refer to target (device) memory obtained via
// targetMalloc/targetCalloc in congrad_64 and released there on exit.

// Target copy of the CG source vector (sites_on_node wilson_vectors).
wilson_vector* t_src;

// Target copies of the CG work vectors and of the solution vector,
// each sites_on_node wilson_vectors.
wilson_vector* t_cg_p;
wilson_vector* t_cg_mp;
wilson_vector* t_cg_temp;
wilson_vector* t_cg_res;
wilson_vector* t_dest;

// Target-resident tables of 8 target pointers; each entry points to a
// separately allocated target buffer of sites_on_node elements.
// NOTE(review): presumably one entry per lattice direction (+/- x,y,z,t); confirm.
half_wilson_vector** t_htmp;
half_wilson_vector** t_htmp_prime;

char*** t_gen_pt;
int** t_neighbor;
// Target copy of the gauge links, reordered into the SU3MIH layout.
su3_matrix *t_gauge;


// Per-site scratch (sites_on_node doubles) reduced with targetDoubleSum.
double* t_result;

// Single double on the target; not otherwise used in this part of the file.
double* t_pkp;

// Host buffers: 8 entries of sites_on_node half_wilson_vectors each.
half_wilson_vector* htmp_prime[8];
// Host and target message buffers, sized from the largest
// offnode_even+offnode_odd count (presumably halo exchange; confirm in dslash comms).
double* sendbuf;
double* recvbuf;
double* t_sendbuf;
double* t_recvbuf;

// Target constant-memory copy of sites_on_node.
__targetConst__ int t_sites_on_node;

// Defined elsewhere: per-direction off-node site counts and the host neighbor table.
extern int offnode_even[8];
extern int offnode_odd[8];
extern int **neighbor;

/* original cg */
/* Compare the final relative residual `prec` against the value recorded in the
 * reference output file for this lattice size, iteration count and node count.
 * Every node builds the file name; only node 0 reads the file and prints the
 * verdict. The reference value is the token following
 * "congrad_orig: end <it> prec=" in the file.  */
static void congrad64_validate(double prec, int maxniter)
{
    char ref_file[128];
    snprintf(ref_file, sizeof(ref_file),
             "output_ref/kernel_E.output.nx%dny%dnz%dnt%d.i%d.t%d",
             nx, ny, nz, nt, maxniter, numnodes_KE());

    if( mynode_KE(  ) != 0 )
        return;

    printf( "\n\nValidating against %s:\n",ref_file );

    FILE* f = fopen( ref_file, "r" );
    if( f == NULL ){
        printf( "Can not validate: missing ref file\n" );
        return;
    }

    char readname[80], itstr[80], preclabel[80];
    float val = 0.0f;
    int found = 0;

    /* Every fscanf result is checked, so reaching EOF without a match ends
     * the scan instead of looping forever; %79s bounds each token to the
     * 80-byte buffers. */
    while( !found && fscanf( f, "%79s", readname ) == 1 ){
        if( strcmp( readname,"congrad_orig:") != 0 )
            continue;
        if( fscanf( f, "%79s", readname ) != 1 )
            break;
        if( strcmp( readname,"end") != 0 )
            continue;
        /* skip the iteration count and the "prec=" label, then read the value */
        if( fscanf( f, "%79s %79s %79s", itstr, preclabel, readname ) != 3 )
            break;
        val = atof( readname );
        found = 1;
    }
    fclose(f);

    if( !found ){
        printf( "Can not validate: reference value not found in ref file\n" );
        return;
    }

    if( fabs(val-prec) < 0.000001 )
        printf("VALIDATION PASSED\n");
    else
        printf("VALIDATION FAILED: prec of %f does not equal ref value of %f \n",prec,val);
}

/* Benchmark CG solve on the target via targetDP.
 *
 * Sets up target copies of the gauge field, neighbor tables and CG vectors.
 * Then it runs the CG recurrence with the (M+M) operator applied through two
 * multiply_fmat calls per iteration, until 2*maxniter matrix multiplies have
 * been done (the residual-based stop is intentionally commented out). It
 * prints timings, frees all target and host resources, and checks the final
 * relative residual against a reference file.
 *
 * NOTE: `src` is not read. The target source vector is filled with 1.0.
 * `dest` is only touched (cleared) when source_norm <= myrsqmin; the solution
 * stays in t_dest on the target.
 *
 * Returns the number of matrix multiplies performed (nmatmul).
 */
__targetHost__ int congrad_64( wilson_vector *src, wilson_vector *dest, int maxniter, double myrsqmin)
{
    double t1, t2, t3, t4;

    int it = 0, i;		/* counter for iterations */
    site *s;
    double alpha, mybeta;

    double source_norm;
    double rsq, oldrsq, pkp;

    //targetdp initialisation
    void* tmpptr;
    int id;


#ifdef CUDA

    node0_printf("\n--- Running on GPU ---\n");
    //set the GPU device for this MPI task
#ifdef GPUSPN
#define GPUS_PER_NODE GPUSPN
#else
#define GPUS_PER_NODE 1
#endif

    node0_printf("(with a maximum of %d GPU(s) per physical node)\n\n",GPUS_PER_NODE);

    int devicenum=this_node%GPUS_PER_NODE;
    cudaError_t cuerr = cudaSetDevice(devicenum);
    if (cuerr != cudaSuccess){
        // A failed device selection would otherwise silently run on device 0.
        fprintf(stderr,"MPI task %d: cudaSetDevice(%d) failed: %s\n",
                this_node,devicenum,cudaGetErrorString(cuerr));
        MPI_Abort(MPI_COMM_WORLD,1);
    }

    cudaDeviceSetCacheConfig(cudaFuncCachePreferL1);

    MPI_Barrier(MPI_COMM_WORLD);
    int devicenum_reported;
    cudaGetDevice(&devicenum_reported);
    printf("MPI task %d is running on GPU with ID %d\n",this_node,devicenum_reported);
    MPI_Barrier(MPI_COMM_WORLD);
#else

    int nThreads = 1;
    // Only one thread records the team size; writing it from every thread races.
#pragma omp parallel
    {
#pragma omp single
        nThreads=omp_get_num_threads();
    }


    // default data layout is Array of Structs of Short Arrays
    const char* dataLayout="AoSoA";

#ifdef SoA
    dataLayout="SoA"; //Struct of Arrays
#endif

#ifdef AoS
    dataLayout="AoS"; //Arrays of Structs
#endif

    node0_printf("\n--- Running on CPU or Xeon Phi with %d OpenMP threads,\n a Virtual Vector Length of %d,\n and %s data layout ---\n\n",nThreads,VVL,dataLayout);

#endif


    //allocate target copies
    targetMalloc((void**) &t_gauge,sites_on_node*4*3*3*2*sizeof(double));

    // get gauge in right layout for congrad loop and copy to target:
    // the host `gauge` is indexed as [site][dir][colour1][colour2][reim];
    // SU3MIH gives the target layout index for the same element.
    double* tmpptr1=(double*) gauge;
    double* tmpgauge = (double*) malloc(sites_on_node*4*3*3*2*sizeof(double));
    if (tmpgauge == NULL){
        fprintf(stderr,"congrad_64: failed to allocate host gauge staging buffer\n");
        MPI_Abort(MPI_COMM_WORLD,1);
    }
    int site_;
#pragma omp parallel for
    for(site_=0;site_<sites_on_node;site_++){
      int dir_, colour1_, colour2_, reim_;
      for(dir_=0;dir_<4;dir_++)
        for(colour1_=0;colour1_<3;colour1_++)
          for(colour2_=0;colour2_<3;colour2_++)
            for(reim_=0;reim_<2;reim_++)
              tmpgauge[SU3MIH(site_,colour1_,colour2_,dir_,reim_)]=tmpptr1[((site_)*4*3*3*2+(dir_)*3*3*2+(colour1_)*3*2+(colour2_)*2+(reim_))];
    }

    copyToTarget(t_gauge,tmpgauge,sites_on_node*4*3*3*2*sizeof(double));

    free(tmpgauge);


    targetCalloc((void**) &t_cg_p,sites_on_node*sizeof(wilson_vector));
    targetCalloc((void**) &t_cg_mp,sites_on_node*sizeof(wilson_vector));
    targetCalloc((void**) &t_cg_temp,sites_on_node*sizeof(wilson_vector));
    targetCalloc((void**) &t_cg_res,sites_on_node*sizeof(wilson_vector));
    targetCalloc((void**) &t_src,sites_on_node*sizeof(wilson_vector));
    targetCalloc((void**) &t_dest,sites_on_node*sizeof(wilson_vector));
    targetCalloc((void**) &t_htmp,8*sizeof(half_wilson_vector*));
    targetCalloc((void**) &t_htmp_prime,8*sizeof(half_wilson_vector*));
    targetCalloc((void**) &t_gen_pt,8*sizeof(char**));
    targetCalloc((void**) &t_neighbor,8*sizeof(int*));
    targetCalloc((void**) &t_result,sites_on_node*sizeof(double));
    targetCalloc((void**) &t_pkp,sizeof(double));


    // Each of the 8 entries of the target pointer tables gets its own target
    // buffer; the buffer address is written into the target-resident table.
    for(id=0;id<8;id++){

      targetCalloc((void**) &tmpptr,sites_on_node*sizeof(half_wilson_vector));
      copyToTarget(&(t_htmp[id]),&tmpptr,sizeof(half_wilson_vector*));

      targetCalloc((void**) &tmpptr,sites_on_node*sizeof(half_wilson_vector));
      copyToTarget(&(t_htmp_prime[id]),&tmpptr,sizeof(half_wilson_vector*));

      targetCalloc((void**) &tmpptr,sites_on_node*sizeof(half_wilson_vector));
      copyToTarget(&(t_gen_pt[id]),&tmpptr,sizeof(char**));

      targetCalloc((void**) &tmpptr,sites_on_node*sizeof(int));
      copyToTarget(&(t_neighbor[id]),&tmpptr,sizeof(int*));

    }


    copyConstToTarget(&t_sites_on_node,&sites_on_node,sizeof(int));

    for( id=0; id <8; id++){
      htmp_prime[id]=(half_wilson_vector*) malloc(sites_on_node*sizeof(half_wilson_vector));
    }


    // Message buffers are sized for the largest per-direction off-node count.
    int maxTranSize=0;
    for( id=0; id <8; id++){
      int tranSize=offnode_even[id]+offnode_odd[id];
      if (tranSize > maxTranSize)
        maxTranSize=tranSize;
    }

    if (maxTranSize > 0){
      sendbuf= (double*) malloc(maxTranSize*sizeof(half_wilson_vector));
      recvbuf= (double*) malloc(maxTranSize*sizeof(half_wilson_vector));

      targetMalloc((void**) &t_sendbuf,maxTranSize*sizeof(half_wilson_vector));
      targetMalloc((void**) &t_recvbuf,maxTranSize*sizeof(half_wilson_vector));
    }

    //end targetdp initialisation


    double mytime;
    mytime = -dclock(  );

    int nmatmul=0;

    malloc_cg64();

    //setup comms
    setup_dslash_comms();


    // Fill each target neighbor buffer from the host neighbor table.
    for( id=0; id < 8; id++){
      copyFromTarget(&tmpptr,&(t_neighbor[id]),sizeof(int*));
      copyToTarget(tmpptr,neighbor[id],sites_on_node*sizeof(int));
    }

    targetZero((double*) t_dest,sites_on_node*(sizeof(wilson_vector)/sizeof(double)));
    targetSetConstant((double*) t_src,1.,sites_on_node*(sizeof(wilson_vector)/sizeof(double)));

    // t_dest is zero,
    // but the below is kept for consistency with the original benchmark
    multiply_fmat( t_dest, t_cg_temp, 1 );
    multiply_fmat( t_cg_temp, t_cg_mp, -1 );

    /*r=p=src-(M+M)*dest, rsq=|r|^2, source_norm=|src|^2 */
    source_norm = rsq = 0;

    sub_wilson_vector_lattice __targetLaunch__(sites_on_node) ( &( t_src[0] ), &( t_cg_mp[0] ), &( t_cg_res[0] ) );
    targetSynchronize();


    copyOnTarget(t_cg_p,t_cg_res,sites_on_node*sizeof(wilson_vector));


    magsq_wvec_lattice __targetLaunch__(sites_on_node) ( &(t_cg_res[0]),t_result);
    targetSynchronize();

    rsq+=targetDoubleSum(t_result, sites_on_node);


    magsq_wvec_lattice __targetLaunch__(sites_on_node) ( &(t_src[0]),t_result);
    targetSynchronize();

    source_norm+=targetDoubleSum(t_result, sites_on_node);

    // Local sums -> global sums across MPI tasks.
    g_doublesum_KE( &rsq );
    g_doublesum_KE( &source_norm );


    double t5,t6;

    node0_printf("\n\nStarting congrad loop...\n");
    MPI_Barrier( MPI_COMM_WORLD );

    t5=omp_get_wtime();

    it=0;

    while (/*rsq>rsqstop &&*/ (nmatmul)<2*maxniter) //start of congrad loop
    {

      t3=omp_get_wtime();
      t1=omp_get_wtime();

        double time;

        /* // rsq -> oldrsq  */
        oldrsq = rsq;

        /* // mp = M+M*p, pkp = p*mp  */
        multiply_fmat( t_cg_p, t_cg_temp, 1 );
        multiply_fmat( t_cg_temp, t_cg_mp, -1 );

#ifdef VERBOSE_TIMINGS
        t2=omp_get_wtime();node0_printf("multiply_fmat %1.16e s\n",t2-t1);t1=omp_get_wtime();
#endif
        nmatmul+=2;

        pkp = 0.0;

        wvec_rdot_tdp_lattice __targetLaunch__(sites_on_node) (t_result, t_cg_p, t_cg_mp);
        targetSynchronize();

#ifdef VERBOSE_TIMINGS
        t2=omp_get_wtime();time=t2-t1;node0_printf("wvec_rdot %1.16e s %1.16e GB/s \n",time,sites_on_node*392./(time*1073741824.));t1=omp_get_wtime();
#endif

        pkp+=targetDoubleSum(t_result, sites_on_node);

#ifdef VERBOSE_TIMINGS
        t2=omp_get_wtime();node0_printf("target_doublesum %1.16e s\n",t2-t1);t1=omp_get_wtime();
#endif

        //this is an MPI_Allreduce - do on host
        g_doublesum_KE( &pkp );

#ifdef VERBOSE_TIMINGS
        t2=omp_get_wtime();node0_printf("g_doublesum %1.16e s\n",t2-t1);t1=omp_get_wtime();
#endif

        // step length along p (named mybeta here; alpha below is the p-update coefficient)
        mybeta = rsq / pkp;

        rsq = 0.;

        // r = r - mybeta * mp
        scalar_mult_add_wvec_lattice __targetLaunch__(sites_on_node) (t_cg_res,t_cg_mp,-mybeta,t_cg_res);
        targetSynchronize();


#ifdef VERBOSE_TIMINGS
        t2=omp_get_wtime();time=t2-t1;node0_printf("scalar_mult_add_wvec %1.16e s %1.16e GB/s \n",time,sites_on_node*576./(time*1073741824.));t1=omp_get_wtime();
#endif


        magsq_wvec_lattice __targetLaunch__(sites_on_node) ( &(t_cg_res[0]),t_result);
        targetSynchronize();

#ifdef VERBOSE_TIMINGS
        t2=omp_get_wtime();time=t2-t1;node0_printf("magsq_wvec %1.16e s %1.16e GB/s \n",time,sites_on_node*200./(time*1073741824.));t1=omp_get_wtime();
#endif

        rsq+=targetDoubleSum(t_result, sites_on_node);

#ifdef VERBOSE_TIMINGS
        t2=omp_get_wtime();node0_printf("targetDoublesum %1.16e\n",t2-t1);t1=omp_get_wtime();
#endif
        g_doublesum_KE( &rsq );

#ifdef VERBOSE_TIMINGS
        t2=omp_get_wtime();node0_printf("g doublesum %1.16e\n",t2-t1);t1=omp_get_wtime();
#endif

        alpha = rsq / oldrsq;

        // p = r + alpha * p
        scalar_mult_add_wvec_lattice __targetLaunch__(sites_on_node) (t_cg_res,t_cg_p,alpha,t_cg_p);
        targetSynchronize();

#ifdef VERBOSE_TIMINGS
        t2=omp_get_wtime();time=t2-t1;node0_printf("scalar_mult_add_wvec %1.16e s %1.16e GB/s \n",time,sites_on_node*576./(time*1073741824.));t1=omp_get_wtime();
#endif

        it++;

#ifdef VERBOSE_TIMINGS
        t2=omp_get_wtime();node0_printf("printf %1.16e\n",t2-t1);t1=omp_get_wtime();

        t4=omp_get_wtime();node0_printf("***full iter %1.16e s\n",t4-t3);
#endif

    } //end of congrad loop


    MPI_Barrier( MPI_COMM_WORLD );

    node0_printf("...Finished congrad loop\n\n");
    t6=omp_get_wtime();node0_printf("******BENCHMARK TIME %1.16e seconds****** \n\n\n",t6-t5);

    finalise_dslash_comms();

    if ((nmatmul)>=2*maxniter)
        node0_fprintf( file_o1, "WARNING congrad_orig: not converged after it=%i: mvm= %d %e > %e\n",
                it, nmatmul, sqrt( rsq / source_norm ), sqrt( myrsqmin ) );

    if( source_norm <= myrsqmin )
    {
        FORALLSITES( i, s ) clear_wvec( &dest[i] );
    }

    /* timing */
    mytime += dclock(  );

    //targetdp finalisation
    //clean up target data

    targetFree(t_cg_p);
    targetFree(t_cg_mp);
    targetFree(t_cg_temp);
    targetFree(t_cg_res);
    targetFree(t_src);
    targetFree(t_dest);
    targetFree(t_gauge);

    // Free every per-entry target buffer allocated in the initialisation loop,
    // including the t_htmp_prime entries (previously leaked).
    for(id=0;id<8;id++){
      copyFromTarget(&tmpptr,&(t_htmp[id]),sizeof(half_wilson_vector*));
      targetFree(tmpptr);

      copyFromTarget(&tmpptr,&(t_htmp_prime[id]),sizeof(half_wilson_vector*));
      targetFree(tmpptr);

      copyFromTarget(&tmpptr,&(t_gen_pt[id]),sizeof(char**));
      targetFree(tmpptr);

      copyFromTarget(&tmpptr,&(t_neighbor[id]),sizeof(int*));
      targetFree(tmpptr);
    }

    targetFree(t_htmp);
    targetFree(t_htmp_prime);
    targetFree(t_neighbor);
    targetFree(t_gen_pt);
    targetFree(t_result);
    targetFree(t_pkp);


    for( id=0; id <8; id++){
      free(htmp_prime[id]);
    }

    if (maxTranSize > 0){
      free(sendbuf);
      free(recvbuf);

      targetFree(t_sendbuf);
      targetFree(t_recvbuf);
    }


    //end targetdp finalisation


    node0_fprintf(file_o1,"congrad_orig: end %d prec= %e mvm= %d time= %.3g\n",
           it,sqrt(rsq/source_norm),nmatmul,mytime);


    verbose_fprintf( file_o1, "congrad_orig: it= %d\t%.3g sec \n",
                     it, mytime);

    congrad64_validate( sqrt(rsq/source_norm), maxniter );

    free_cg64();

    return nmatmul;
}

