#include <stdio.h>
#include "mpi.h"
#include "trap.h"
/*
 * Read the integration parameters on rank 0 and distribute them to every
 * other rank in MPI_COMM_WORLD. Must be called by all ranks.
 *
 * Rank 0 prompts on stdout and reads "a b n" from stdin. It then sends
 * a (tag 30), b (tag 31) and n (tag 32) to ranks 1..nproc-1. Every other
 * rank receives the three values from rank 0 in the same order.
 *
 * me     - rank of the calling process
 * nproc  - number of processes in MPI_COMM_WORLD
 * a_ptr  - out: left endpoint of the integration interval
 * b_ptr  - out: right endpoint of the integration interval
 * n_ptr  - out: number of trapezoids
 *
 * If the input cannot be parsed, rank 0 reports the problem on stderr and
 * distributes a = b = 0 and n = 0. Every rank then sees the same, defined
 * values, and callers should treat n <= 0 as invalid input.
 */
void Get_data(int me, int nproc, float* a_ptr, float* b_ptr, int* n_ptr) {
  const int src = 0;
  const int tag_a = 30, tag_b = 31, tag_n = 32;

  if (me == src) {
    printf("Enter a, b, and n\n");
    /* Under mpirun stdout is frequently fully buffered; flush so the prompt
       is visible before this rank blocks reading stdin. */
    fflush(stdout);
    if (scanf("%f %f %d", a_ptr, b_ptr, n_ptr) != 3) {
      /* Never forward uninitialised values: send a defined, invalid set. */
      fprintf(stderr, "Get_data: expected two floats and an integer\n");
      *a_ptr = 0.0f;
      *b_ptr = 0.0f;
      *n_ptr = 0;
    }
    for (int dst = 1; dst < nproc; dst++) {
      MPI_Send(a_ptr, 1, MPI_FLOAT, dst, tag_a, MPI_COMM_WORLD);
      MPI_Send(b_ptr, 1, MPI_FLOAT, dst, tag_b, MPI_COMM_WORLD);
      MPI_Send(n_ptr, 1, MPI_INT, dst, tag_n, MPI_COMM_WORLD);
    }
  } else {
    MPI_Status status;
    MPI_Recv(a_ptr, 1, MPI_FLOAT, src, tag_a, MPI_COMM_WORLD, &status);
    MPI_Recv(b_ptr, 1, MPI_FLOAT, src, tag_b, MPI_COMM_WORLD, &status);
    MPI_Recv(n_ptr, 1, MPI_INT, src, tag_n, MPI_COMM_WORLD, &status);
  }
}

/*
 * Estimate the integral over [a, b] with the composite trapezoidal rule,
 * splitting the n trapezoids across all ranks of MPI_COMM_WORLD.
 *
 * Rank 0 reads a, b and n via Get_data. Each rank then integrates its own
 * contiguous sub-interval with Trap(). Rank 0 collects the partial results
 * (tag 50) and prints the total.
 *
 * Returns 0 on success, or 1 when n is not a positive trapezoid count.
 */
int main(int argc, char** argv) {
  int me, nproc, n;
  float a, b;

  MPI_Init(&argc, &argv);
  MPI_Comm_rank(MPI_COMM_WORLD, &me);
  MPI_Comm_size(MPI_COMM_WORLD, &nproc);
  Get_data(me, nproc, &a, &b, &n);

  /* Every rank received the same n, so all ranks take this exit together
     and none is left blocked in a send or receive. */
  if (n <= 0) {
    if (me == 0) {
      fprintf(stderr, "n must be a positive number of trapezoids (got %d)\n", n);
    }
    MPI_Finalize();
    return 1;
  }

  const float h = (b - a) / n;

  /* Split the n trapezoids as evenly as possible: the first n % nproc ranks
     take one extra. Using n / nproc alone would silently drop the remainder,
     so the summed sub-intervals would stop short of b. */
  const int base = n / nproc;
  const int extra = n % nproc;
  const int local_n = base + (me < extra ? 1 : 0);
  const int first = me * base + (me < extra ? me : extra);
  const float local_a = a + first * h;
  const float local_b = local_a + local_n * h;

  /* Ranks left without trapezoids (possible when n < nproc) contribute 0.
     Trap's behaviour for a zero count is not visible here, so skip it. */
  float integral = local_n > 0 ? Trap(local_a, local_b, local_n, h) : 0.0f;

  const int tag = 50;
  if (me == 0) {
    float total = integral;
    MPI_Status status;
    for (int src = 1; src < nproc; src++) {
      MPI_Recv(&integral, 1, MPI_FLOAT, src, tag, MPI_COMM_WORLD, &status);
      total += integral;
    }
    printf("With n = %d trapezoids, our estimate\n", n);
    printf("of the integral from %f to %f = %f\n", a, b, total);
  } else {
    const int dst = 0;
    MPI_Send(&integral, 1, MPI_FLOAT, dst, tag, MPI_COMM_WORLD);
  }

  MPI_Finalize();
  return 0;
}