!===============================================================================
!
! BQCD -- Berlin Quantum ChromoDynamics program
!
! Author: Hinnerk Stueben <stueben@zib.de>
!
! Copyright (C) 1998-2005, Hinnerk Stueben, Zuse-Institut Berlin
!
!-------------------------------------------------------------------------------
!
! bqcd.F90 - main program and read/write of parameters
!
!-------------------------------------------------------------------------------
# include "defs.h"

!-------------------------------------------------------------------------------

! JuBE
! use kernel_a as a subroutine in the qcd-bench, this was the main function
! in the original code
subroutine kernel_a()

  ! JuBE benchmark entry point; this was the main program in the original
  ! code.
  !
  ! Order of work: JuBE init, timer start, comm_init(), reading of the input
  ! file "kernel_A.input", the init_*() set-up calls, mc(), then
  ! conf_write() / write_counter() / write_ran() / write_footer() and finally
  ! comm_finalize() and the JuBE end hook.

  use typedef_flags
  use typedef_para
  use module_input
  use module_function_decl
  implicit none

  type(type_para)                       :: para
  type(hmc_conf), dimension(MAX_TEMPER) :: conf
  type(type_flags)                      :: flags
  SECONDS                               :: time0, sekunden
  integer                               :: kernel_number

  kernel_number = 0

  ! JuBE: initial hook
  call jube_kernel_init(kernel_number)

  ! JuBE: get_flags() is not called (see below), so every component it would
  ! have set has to be defined here.  init_para() and init_counter() read
  ! flags%continuation_job; leaving it undefined could select the
  ! continuation path by accident.  The values are the defaults used by
  ! get_flags().
  flags%continuation_job = .false.
  flags%show_version     = .false.

  ! JuBE: the input file name is fixed for the benchmark
  flags%input = "kernel_A.input"

  time0 = sekunden()  ! start/initialize timer

  TIMING_START(timing_bin_total)

  call comm_init()

  ! JuBE: command line arguments are ignored.  None of them except the input
  ! file name (set above) is relevant for the benchmark, hence:
  ! call get_flags(flags)

  call begin(UREC, "Job")
  call input_read(flags%input)
  call init_para(para, flags)
  call init_counter(para, flags)
  call init_ran(para, flags)
  call init_cooling(input%measure_cooling_list)

  call set_fmt_ensemble(para%n_temper)
  call check_fmt(para%run, para%n_temper, para%maxtraj, para%L(4) - 1)

  call init_common(para)
  call init_modules()

  call write_header(para)

  call init_flip_bc()
  call init_cg_para(para%cg_rest, para%cg_maxiter, para%cg_log)
  call init_cg_stat()
  call init_xbound()
  call init_confs(para, conf)

  call check_former(para%n_temper, conf)

  ! JuBE: the run/finalize hooks bracket the Monte Carlo call only
  call jube_kernel_run()

  call mc(para, conf)
  !!call xbound_test()

  call jube_kernel_finalize()

  call conf_write(.true., para, conf)

  call write_counter(para%maxtraj)
  call write_ran()

  TIMING_STOP(timing_bin_total)

  call write_footer(time0)
  call end_A(UREC, "Job")

  call comm_finalize()

  ! JuBE: end hook
  call jube_kernel_end()

end subroutine kernel_a

!-------------------------------------------------------------------------------
subroutine init_para(para, flags)

  ! Fills "para" from the values held in module_input and sets the switches
  ! of module_switches as well as mre_n_vec of module_mre.
  !
  !   para   receives the scalar run parameters and, for each ensemble, both
  !          the input strings (c_hmc) and the numbers read from them (hmc)
  !   flags  only flags%continuation_job is read; if set, it overrides
  !          para%start with START_CONT
  !
  ! die() is called for: n_temper out of range, an unknown
  ! start_configuration or tempering_swap_sequence keyword, a mixture of
  ! quenched and dynamical ensembles, and nforce < 0.

  use typedef_flags
  use typedef_para
  use module_bqcd
  use module_input
  use module_mre
  use module_switches
  implicit none

  type(type_para)  :: para
  type(type_flags) :: flags
  integer          :: i_ens
  logical          :: quenched, dynamical, clover, h_ext
  logical          :: kappa_is_zero

  ! --- run-wide parameters ---------------------------------------------------

  para%run         = input%run
  para%L           = input%lattice
  para%NPE         = input%processes
  para%bc_fermions = input%boundary_conditions_fermions
  para%gamma_index = input%gamma_index
  para%n_temper    = input%ensembles
  para%nstd        = input%tempering_steps_without
  para%nforce      = input%hmc_accept_first
  para%ntraj       = input%mc_steps
  para%maxtraj     = input%mc_total_steps
  para%nsave       = input%mc_save_frequency
  para%c_cg_rest   = input%solver_rest
  para%cg_maxiter  = input%solver_maxiter
  para%cg_log      = input%solver_ignore_no_convergence
  mre_n_vec        = input%solver_mre_vectors

  call check_bc_fermions(para%bc_fermions, para%gamma_index)

  read(para%c_cg_rest, *) para%cg_rest

  if (para%n_temper <= 0) then
     call die("init_para(): n_temper <= 0")
  endif
  if (para%n_temper > MAX_TEMPER) then
     call die("init_para(): n_temper > max_temper")
  endif

  ! --- per-ensemble parameters -----------------------------------------------

  quenched  = .false.
  dynamical = .false.
  clover    = .false.
  h_ext     = .false.

  do i_ens = 1, para%n_temper

     ! each value is kept as the input string and as the number read from it
     para%c_hmc(i_ens)%beta = input%beta(i_ens)
     read(para%c_hmc(i_ens)%beta, *) para%hmc(i_ens)%beta

     para%c_hmc(i_ens)%kappa = input%kappa(i_ens)
     read(para%c_hmc(i_ens)%kappa, *) para%hmc(i_ens)%kappa

     para%c_hmc(i_ens)%csw = input%csw(i_ens)
     read(para%c_hmc(i_ens)%csw, *) para%hmc(i_ens)%csw

     para%c_hmc(i_ens)%h = input%h(i_ens)
     read(para%c_hmc(i_ens)%h, *) para%hmc(i_ens)%h

     para%c_hmc(i_ens)%traj_length = input%hmc_trajectory_length(i_ens)
     read(para%c_hmc(i_ens)%traj_length, *) para%hmc(i_ens)%traj_length

     para%c_hmc(i_ens)%ntau = input%hmc_steps(i_ens)
     read(para%c_hmc(i_ens)%ntau, *) para%hmc(i_ens)%ntau

     para%c_hmc(i_ens)%rho = input%hmc_rho(i_ens)
     read(para%c_hmc(i_ens)%rho, *) para%hmc(i_ens)%rho

     para%c_hmc(i_ens)%m_scale = input%hmc_m_scale(i_ens)
     read(para%c_hmc(i_ens)%m_scale, *) para%hmc(i_ens)%m_scale

     para%info_file(i_ens) = input%start_info_file(i_ens)
     para%hmc(i_ens)%model = input%hmc_model

     kappa_is_zero = (para%hmc(i_ens)%kappa == ZERO)

     if (kappa_is_zero .and. para%hmc(i_ens)%csw /= ZERO) then
        ! kappa == 0 with csw /= 0: the given csw is taken as csw_kappa
        ! itself and csw is flagged with -1 ("infinity")
        para%hmc(i_ens)%csw_kappa = para%hmc(i_ens)%csw
        para%c_hmc(i_ens)%csw = "-1 (infinity)"
        para%hmc(i_ens)%csw = -1
     else
        para%hmc(i_ens)%csw_kappa = para%hmc(i_ens)%csw * para%hmc(i_ens)%kappa
        call check_csw(para%hmc(i_ens)%beta, para%hmc(i_ens)%csw)
     endif

     para%hmc(i_ens)%tau = para%hmc(i_ens)%traj_length / para%hmc(i_ens)%ntau

     write(para%c_hmc(i_ens)%csw_kappa, "(f20.15)") para%hmc(i_ens)%csw_kappa
     write(para%c_hmc(i_ens)%tau,       "(f20.15)") para%hmc(i_ens)%tau

     ! classification uses csw *after* the possible replacement by -1 above
     if (kappa_is_zero .and. para%hmc(i_ens)%csw == ZERO) then
        quenched = .true.
     else
        dynamical = .true.
     endif

     clover = clover .or. (para%hmc(i_ens)%csw /= ZERO)
     h_ext  = h_ext  .or. (para%hmc(i_ens)%h   /= ZERO)

     ! model "A" goes with rho == 0, every other model with rho /= 0
     if (para%hmc(i_ens)%model == "A") then
        if (para%hmc(i_ens)%rho /= ZERO) then
           call warn("init_para(): model == A but rho /= 0")
        endif
     else
        if (para%hmc(i_ens)%rho == ZERO) then
           call warn("init_para(): model /= A but rho == 0")
        endif
     endif

  enddo

  ! --- keyword valued input --------------------------------------------------

  select case (input%start_configuration)
  case ("hot")
     para%start = START_HOT
  case ("cold")
     para%start = START_COLD
  case ("file")
     para%start = START_FILE
  case default
     call die("init_para(): start_configuration must be {hot|cold|file}")
  end select

  ! anything that is not a keyword is read as the seed itself
  select case (input%start_random)
  case ("random")
     para%seed = -1
  case ("default")
     para%seed = 0
  case default
     read(input%start_random, *) para%seed
  end select

  select case (input%tempering_swap_sequence)
  case ("up")
     para%swap_seq = SWAP_UP
  case ("down")
     para%swap_seq = SWAP_DOWN
  case ("random")
     para%swap_seq = SWAP_RANDOM
  case default
     call die("init_para(): tempering_swap_sequence must be {up|down|random}")
  end select

  ! --- consistency checks ----------------------------------------------------

  if (quenched .and. dynamical) then
     call die("init_para(): quenched/dynamical mixed")
  endif

  if (para%nforce < 0) then
     call die("init_para(): nforce < 0")
  endif

  if (flags%continuation_job) para%start = START_CONT

  ! --- switches --------------------------------------------------------------

  switches%quenched   = quenched
  switches%dynamical  = dynamical
  switches%clover     = clover
  switches%h_ext      = h_ext
  switches%hasenbusch = (input%hmc_model /= "A") .and. (.not. quenched)

  switches%tempering             = (input%ensembles             >  1)
  switches%measure_polyakov_loop = (input%measure_polyakov_loop /= 0)
  switches%measure_traces        = (input%measure_traces        /= 0)
  switches%hmc_test              = (input%hmc_test              /= 0)

end subroutine init_para

!-------------------------------------------------------------------------------
subroutine init_counter(para, flags)

  ! Initialises the counter in module_counter (run, job, traj, j_traj).
  !
  !   para   supplies para%run and para%nforce
  !   flags  flags%continuation_job selects between a fresh start and
  !          continuing from the count file
  !
  ! Fresh start:  job = 1 and traj = -para%nforce.
  ! Continuation: run, job and traj are read from the count file (the first
  !               item of each of its first three lines); job is then
  !               incremented.
  !
  ! die() is called if the stop file exists, if the count file cannot be
  ! opened or read, or if the run number in the count file differs from
  ! para%run.

  use typedef_flags
  use typedef_para
  use module_counter
  use module_function_decl
  implicit none

  type(type_para)    :: para
  type(type_flags)   :: flags
  FILENAME, external :: count_file, stop_file
  integer            :: io_stat

  if (f_exist(stop_file())) then
     call die("init_counter(): found stop file " // stop_file())
  endif

  counter%run = para%run
  counter%j_traj = 0

  if (flags%continuation_job) then
     ! A missing or damaged count file is reported through die() like every
     ! other error here, instead of ending in a bare runtime I/O error.
     open(UCOUNT, file = count_file(), action = "read", status = "old", &
          iostat = io_stat)
     if (io_stat /= 0) then
        call die("init_counter(): cannot open count file " // count_file())
     endif

     read(UCOUNT, *, iostat = io_stat) counter%run
     if (io_stat == 0) read(UCOUNT, *, iostat = io_stat) counter%job
     if (io_stat == 0) read(UCOUNT, *, iostat = io_stat) counter%traj
     close(UCOUNT)

     if (io_stat /= 0) then
        call die("init_counter(): cannot read count file " // count_file())
     endif

     if (counter%run /= para%run) call die("init_counter(): RUN inconsistent")
     counter%job = counter%job + 1
  else
     ! counter%run has already been set from para%run above
     counter%job = 1
     counter%traj = -para%nforce
  endif

end subroutine init_counter

!-------------------------------------------------------------------------------
subroutine write_counter(maxtraj)

  ! Writes counter%run, counter%job and counter%traj (one per line, each
  ! followed by its name) to the count file.  If counter%traj has reached
  ! maxtraj, the stop file is opened and closed again.
  !
  ! Only the process with my_pe() == 0 does any work.

  use module_counter
  use module_function_decl
  implicit none

  integer            :: maxtraj
  FILENAME, external :: count_file, stop_file

  if (my_pe() == 0) then

     open(unit = UCOUNT, file = count_file(), action = "write")
     write(unit = UCOUNT, fmt = *) counter%run, " run"
     write(unit = UCOUNT, fmt = *) counter%job, " job"
     write(unit = UCOUNT, fmt = *) counter%traj, " traj"
     close(unit = UCOUNT)

     ! nothing is written to the stop file
     if (counter%traj >= maxtraj) then
        open(unit = UCOUNT, file = stop_file(), status = "unknown")
        close(unit = UCOUNT)
     endif

  endif

end subroutine write_counter

!-------------------------------------------------------------------------------
subroutine write_header(para)

  ! Writes the "Header" section to unit UREC: program and run identification,
  ! lattice/process layout, start and Monte Carlo parameters, the
  ! per-ensemble HMC parameters (as the strings stored in para%c_hmc) and
  ! the solver settings.
  !
  ! Only the process with my_pe() == 0 writes; all others return at once.

  use typedef_para
  use module_bqcd
  use module_counter
  use module_function_decl
  use module_input
  use module_mre
  use module_thread
  implicit none

  type(type_para)     :: para
  integer             :: i_ens
  character(len = 50) :: fmt_ens
  character(len = 4), external :: format_ensemble

  character(len = *), parameter :: fmt_3a  = "(3(1x,a))"
  character(len = *), parameter :: fmt_2a  = "(2(1x,a))"
  character(len = *), parameter :: fmt_a4i = "(1x,a,4i3)"

  if (my_pe() /= 0) return

  ! format for "<name><ensemble number><padding><value>" lines
  fmt_ens = "(1x,a," // format_ensemble() // ",2a)"

  call begin(UREC, "Header")

  if (len_trim(input%comment) > 0) then
     write(UREC, fmt_2a) "Comment", trim(input%comment)
  endif

  write(UREC, fmt_3a)  "Program", prog_name, prog_version
  write(UREC, *)       "Version_of_D ", version_of_d()
  write(UREC, *)       "Communication ", trim(comm_method())
  write(UREC, *)       "Run ", para%run
  write(UREC, *)       "Job ", counter%job
  write(UREC, fmt_2a)  "Host", rechner()
  write(UREC, fmt_3a)  "Date", datum(), uhrzeit()
  write(UREC, fmt_a4i) "L          ", para%L
  write(UREC, fmt_a4i) "NPE        ", para%NPE
  write(UREC, fmt_a4i) "bc_fermions", para%bc_fermions
  write(UREC, fmt_a4i) "gamma_index", para%gamma_index

  write(UREC, *) "Threads ", n_thread
  write(UREC, *) "Start   ", para%start

  ! start configuration files are listed only when starting from file
  if (para%start == START_FILE) then
     do i_ens = 1, para%n_temper
        write(UREC, fmt_ens) "StartConf_", i_ens, " ", &
                             trim(para%info_file(i_ens))
     enddo
  endif

  write(UREC, *) "Seed    ", para%seed
  write(UREC, *) "Swap_seq", para%swap_seq
  write(UREC, *) "N_force ", para%nforce
  write(UREC, *) "N_traj  ", para%ntraj
  write(UREC, *) "N_save  ", para%nsave
  write(UREC, *) "N_temper", para%n_temper

  do i_ens = 1, para%n_temper
     write(UREC, fmt_ens) "beta_", i_ens, "        ", &
                          trim(para%c_hmc(i_ens)%beta)
     write(UREC, fmt_ens) "kappa_", i_ens, "       ", &
                          trim(para%c_hmc(i_ens)%kappa)
     write(UREC, fmt_ens) "csw_", i_ens, "         ", &
                          trim(para%c_hmc(i_ens)%csw)
     write(UREC, fmt_ens) "csw_kappa_", i_ens, "   ", &
                          trim(para%c_hmc(i_ens)%csw_kappa)
     write(UREC, fmt_ens) "h_", i_ens, "           ", &
                          trim(para%c_hmc(i_ens)%h)
     write(UREC, fmt_ens) "tau_", i_ens, "         ", &
                          trim(para%c_hmc(i_ens)%tau)
     write(UREC, fmt_ens) "N_tau_", i_ens, "       ", &
                          trim(para%c_hmc(i_ens)%ntau)
     write(UREC, fmt_ens) "traj_length_", i_ens, " ", &
                          trim(para%c_hmc(i_ens)%traj_length)
     write(UREC, fmt_ens) "rho_", i_ens, "         ", &
                          trim(para%c_hmc(i_ens)%rho)
     write(UREC, fmt_ens) "m_scale_", i_ens, "     ", &
                          trim(para%c_hmc(i_ens)%m_scale)
  enddo

  ! the model is reported from the first ensemble
  write(UREC, *)      "HMC_model ", para%hmc(1)%model
  write(UREC, *)      "REAL_kind ", RKIND
  write(UREC, fmt_2a) "CG_rest ", trim(para%c_cg_rest)
  write(UREC, *)      "MRE_vectors ", mre_n_vec

  call end_A(UREC, "Header")

end subroutine write_header

!-------------------------------------------------------------------------------
subroutine write_footer(time0)

  ! Writes the "Footer" section to unit UREC: date and time, the seed
  ! returned by ranget(), and the time elapsed since time0 together with
  ! num_pes() * n_thread as the number of CPUs.
  !
  !   time0  timer value taken at program start (see sekunden())
  !
  ! ranget(), begin(), TIMING_WRITE and end_A() are executed by every
  ! process; the write statements only by the process with my_pe() == 0.

  use module_function_decl
  use module_thread
  implicit none

  SEED    :: seed
  SECONDS :: time0, sekunden
  SECONDS :: elapsed

  character(len = *), parameter :: fmt_3a  = "(3(1x,a))"
  character(len = *), parameter :: fmt_cpu = "(1x,a,1x,f8.1,1x,a,1x,i5,1x,a)"

  call ranget(seed)

  call begin(UREC, "Footer")

  if (my_pe() == 0) then
     write(UREC, fmt_3a) "Date", datum(), uhrzeit()
     write(UREC, *)      "Seed", seed

     elapsed = sekunden() - time0
     write(UREC, fmt_cpu) "CPU-Time", elapsed, "s on", &
                          num_pes() * n_thread, "CPUs"
  endif

  TIMING_WRITE(UREC)

  call end_A(UREC, "Footer")

end subroutine write_footer

!-------------------------------------------------------------------------------
subroutine get_flags(flags)

  ! Evaluates the command line:  [-c] [-I] [-V] input
  !
  !   -c   sets flags%continuation_job
  !   -I   calls input_dump(6), then comm_finalize() and stops
  !   -V   prints version information, then comm_finalize() and stops
  !
  ! Options are read until the first argument that does not start with "-".
  ! That argument is stored in flags%input; it must be the last one.
  ! Everything else ends in usage(), which calls die().

  use typedef_cksum
  use typedef_flags
  use module_bqcd
  use module_function_decl
  use module_input
  implicit none

  type(type_flags), intent(out) :: flags

  integer                       :: pos, n_chars, rc, n_args
  integer, external             :: ipxfargc
  character(len = 2)            :: word

  flags%continuation_job = .false.
  flags%show_version     = .false.

  n_args = ipxfargc()
  pos = 1

  do
     if (pos > n_args) exit

     call pxfgetarg(pos, word, n_chars, rc)

     ! first non-option argument ends option processing
     if (word(1:1) /= "-") exit

     ! options are exactly one letter after the "-"
     if (n_chars > 2) call usage()

     select case (word(2:2))
     case ("c")
        flags%continuation_job = .true.
        pos = pos + 1
     case ("V")
        flags%show_version = .true.
        pos = pos + 1
     case ("I")
        call input_dump(6)
        call comm_finalize()
        stop
     case default
        call usage()
     end select
  enddo

  if (flags%show_version) then
     call version()
     call comm_finalize()
     stop
  endif

  call take_arg(pos, flags%input, n_args)

  ! nothing may follow the input file name
  if (pos <= n_args) call usage()

CONTAINS

  ! Reports the command line syntax through die().
  subroutine usage()
    implicit none
    call die("Usage: " // prog_name // " [-c] [-I] [-V] input")
  end subroutine usage

  ! Prints program, format and build information on unit 6 (process 0 only).
  subroutine version()
    implicit none

    if (my_pe() /= 0) return

    write(6,*) "This is ", prog_name, " ", prog_version
    write(6,*) "   input format:    ", input_version
    write(6,*) "   conf info format:", conf_info_version
    write(6,*) "   MAX_TEMPER:      ", MAX_TEMPER
    write(6,*) "   real kind:       ", RKIND
    write(6,*) "   version of D:    ", version_of_d()
    write(6,*) "   D3: buffer vol:  ", get_d3_buffer_vol()  
    write(6,*) "   communication:   ", trim(comm_method())
  end subroutine version

  ! Stores argument number iarg in arg and advances iarg by one.
  ! Calls usage() if there is no such argument and die() if the argument
  ! is longer than arg.
  subroutine take_arg(iarg, arg, narg)
    implicit none
    integer, intent(inout)          :: iarg
    character(len = *), intent(out) :: arg
    integer, intent(in)             :: narg
    integer                         :: arg_len, arg_rc

    if (iarg > narg) call usage()

    call pxfgetarg(iarg, arg, arg_len, arg_rc)

    if (arg_len > len(arg)) then
       call die("get_flags(): " // arg // ": argument too long")
    endif

    iarg = iarg + 1
  end subroutine take_arg

end subroutine get_flags

!===============================================================================
