#include <stdio.h>
#include <string.h>
#include <assert.h>
#include "ast.h"


/* Constructor functions for AST nodes. */

/*
 * Allocate `size` bytes, terminating the program if the allocation fails.
 *
 * The size is a size_t because every caller passes a sizeof or a
 * strlen()-derived value.  The failure check is an explicit test rather
 * than an assert() so that it is still performed when the file is built
 * with NDEBUG; callers dereference the result without testing it.
 *
 * Returns a non-NULL pointer to uninitialized storage.
 */
static void *checked_malloc(size_t size)
{
  /* This file does not include <stdlib.h>, so the two library functions
     used here are declared locally with their standard prototypes. */
  extern void *malloc(size_t);
  extern void abort(void);
  void *p = malloc(size);

  if (p == NULL) {
    fprintf(stderr, "checked_malloc: out of memory (%lu bytes requested)\n",
            (unsigned long) size);
    abort();
  }
  return p;
}


/* Allocate a Program node that holds the body b. */
Program mk_Program(Body b)
{
  Program prog = checked_malloc(sizeof *prog);

  prog->b = b;
  return prog;
}

/* Allocate a Body node storing the line number, the declaration list dl
   and the statement s. */
Body mk_Body(int line, DecList dl, Stat s)
{
  Body node = checked_malloc(sizeof *node);

  node->s = s;
  node->dl = dl;
  node->line = line;
  return node;
}

/* Allocate a DecList cell holding d whose tail is next. */
DecList mk_DecList(Dec d, DecList next)
{
  DecList cell = checked_malloc(sizeof *cell);

  cell->next = next;
  cell->d = d;
  return cell;
}

/* Allocate a Dec node tagged VarDecs that carries the variable
   declaration list vl. */
Dec mk_VarDecs(VarDecList vl)
{
  Dec node = checked_malloc(sizeof *node);

  node->u.vardecs.vl = vl;
  node->kind = VarDecs;
  return node;
}

/* Allocate a Dec node tagged TypeDecs that carries the type
   declaration list tl. */
Dec mk_TypeDecs(TypeDecList tl)
{
  Dec node = checked_malloc(sizeof *node);

  node->u.typedecs.tl = tl;
  node->kind = TypeDecs;
  return node;
}

/* Allocate a Dec node tagged ProcDecs that carries the procedure
   declaration list pl. */
Dec mk_ProcDecs(ProcDecList pl)
{
  Dec node = checked_malloc(sizeof *node);

  node->u.procdecs.pl = pl;
  node->kind = ProcDecs;
  return node;
}

/* Allocate a VarDecList cell holding v whose tail is next. */
VarDecList mk_VarDecList(VarDec v, VarDecList next)
{
  VarDecList cell = checked_malloc(sizeof *cell);

  cell->next = next;
  cell->v = v;
  return cell;
}

/* Allocate a VarDec node for identifier i with type t and expression e,
   recorded at the given line number. */
VarDec mk_VarDec(int line, id i, Type t, Exp e)
{
  VarDec node = checked_malloc(sizeof *node);

  node->e = e;
  node->t = t;
  node->i = i;
  node->line = line;
  return node;
}

/* Allocate a TypeDecList cell holding t whose tail is next. */
TypeDecList mk_TypeDecList(TypeDec t, TypeDecList next)
{
  TypeDecList cell = checked_malloc(sizeof *cell);

  cell->next = next;
  cell->t = t;
  return cell;
}


/* Allocate a TypeDec node binding identifier i to type t, recorded at
   the given line number. */
TypeDec mk_TypeDec(int line, id i, Type t)
{
  TypeDec node = checked_malloc(sizeof *node);

  node->t = t;
  node->i = i;
  node->line = line;
  return node;
}

/* Allocate a ProcDecList cell holding p whose tail is next. */
ProcDecList mk_ProcDecList(ProcDec p, ProcDecList next)
{
  ProcDecList cell = checked_malloc(sizeof *cell);

  cell->next = next;
  cell->p = p;
  return cell;
}

/* Allocate a ProcDec node for identifier i with formal parameters fl,
   type t and body b, recorded at the given line number. */
ProcDec mk_ProcDec(int line, id i, FormalParamList fl, Type t, Body b)
{
  ProcDec node = checked_malloc(sizeof *node);

  node->b = b;
  node->t = t;
  node->fl = fl;
  node->i = i;
  node->line = line;
  return node;
}

/* Allocate a Type node tagged NamedTyp that refers to identifier i. */
Type mk_NamedType(int line, id i)
{
  Type node = checked_malloc(sizeof *node);

  node->kind = NamedTyp;
  node->u.namedtyp.i = i;
  node->line = line;
  return node;
}

/* Allocate a Type node tagged ArrayTyp whose stored type is t1. */
Type mk_ArrayType(int line, Type t1)
{
  Type node = checked_malloc(sizeof *node);

  node->kind = ArrayTyp;
  node->u.arraytyp.t = t1;
  node->line = line;
  return node;
}

/* Allocate a Type node tagged RecordTyp carrying the component list cl. */
Type mk_RecordType(int line, ComponentList cl)
{
  Type node = checked_malloc(sizeof *node);

  node->kind = RecordTyp;
  node->u.recordtyp.cl = cl;
  node->line = line;
  return node;
}

/* Allocate a ComponentList cell holding c whose tail is next. */
ComponentList mk_ComponentList(Component c, ComponentList next)
{
  ComponentList cell = checked_malloc(sizeof *cell);

  cell->next = next;
  cell->c = c;
  return cell;
}

/* Allocate a Component node pairing identifier i with type t, recorded
   at the given line number. */
Component mk_Component(int line, id i, Type t)
{
  Component node = checked_malloc(sizeof *node);

  node->t = t;
  node->i = i;
  node->line = line;
  return node;
}

/* Allocate a FormalParamList cell holding f whose tail is next. */
FormalParamList mk_FormalParamList(FormalParam f, FormalParamList next)
{
  FormalParamList cell = checked_malloc(sizeof *cell);

  cell->next = next;
  cell->f = f;
  return cell;
}

/* Allocate a FormalParam node pairing identifier i with type t, recorded
   at the given line number. */
FormalParam mk_FormalParam(int line, id i, Type t)
{
  FormalParam node = checked_malloc(sizeof *node);

  node->t = t;
  node->i = i;
  node->line = line;
  return node;
}

/* Allocate a StatList cell holding s whose tail is next. */
StatList mk_StatList(Stat s, StatList next)
{
  StatList cell = checked_malloc(sizeof *cell);

  cell->next = next;
  cell->s = s;
  return cell;
}

/* Allocate a Stat node tagged AssignSt with lvalue l and expression e. */
Stat mk_AssignStat(int line, Lvalue l, Exp e)
{
  Stat node = checked_malloc(sizeof *node);

  node->kind = AssignSt;
  node->u.assignst.e = e;
  node->u.assignst.l = l;
  node->line = line;
  return node;
}

/* Allocate a Stat node tagged CallSt with identifier i and the
   expression list el. */
Stat mk_CallStat(int line, id i, ExpList el)
{
  Stat node = checked_malloc(sizeof *node);

  node->kind = CallSt;
  node->u.callst.el = el;
  node->u.callst.i = i;
  node->line = line;
  return node;
}


/* Allocate a Stat node tagged ReadSt carrying the lvalue list ll. */
Stat mk_ReadStat(int line, LvalueList ll)
{
  Stat node = checked_malloc(sizeof *node);

  node->kind = ReadSt;
  node->u.readst.ll = ll;
  node->line = line;
  return node;
}


/* Allocate a Stat node tagged WriteSt carrying the expression list el. */
Stat mk_WriteStat(int line, ExpList el)
{
  Stat node = checked_malloc(sizeof *node);

  node->kind = WriteSt;
  node->u.writest.el = el;
  node->line = line;
  return node;
}


/* Allocate a Stat node tagged IfSt with expression e and the two
   statements s1 and s2. */
Stat mk_IfStat(int line, Exp e, Stat s1, Stat s2)
{
  Stat node = checked_malloc(sizeof *node);

  node->kind = IfSt;
  node->u.ifst.s2 = s2;
  node->u.ifst.s1 = s1;
  node->u.ifst.e = e;
  node->line = line;
  return node;
}

/* Allocate a Stat node tagged WhileSt with expression e and statement s1. */
Stat mk_WhileStat(int line, Exp e, Stat s1)
{
  Stat node = checked_malloc(sizeof *node);

  node->kind = WhileSt;
  node->u.whilest.s = s1;
  node->u.whilest.e = e;
  node->line = line;
  return node;
}

/* Allocate a Stat node tagged LoopSt wrapping statement s1. */
Stat mk_LoopStat(int line, Stat s1)
{
  Stat node = checked_malloc(sizeof *node);

  node->kind = LoopSt;
  node->u.loopst.s = s1;
  node->line = line;
  return node;
}
  

/* Allocate a Stat node tagged ForSt with identifier i, the three
   expressions e1, e2, e3 and the statement s1. */
Stat mk_ForStat(int line, id i, Exp e1, Exp e2, Exp e3, Stat s1)
{
  Stat node = checked_malloc(sizeof *node);

  node->kind = ForSt;
  node->u.forst.s = s1;
  node->u.forst.e3 = e3;
  node->u.forst.e2 = e2;
  node->u.forst.e1 = e1;
  node->u.forst.i = i;
  node->line = line;
  return node;
}

/* Allocate a Stat node tagged ExitSt; only the line and kind are set. */
Stat mk_ExitStat(int line)
{
  Stat node = checked_malloc(sizeof *node);

  node->kind = ExitSt;
  node->line = line;
  return node;
}

/* Allocate a Stat node tagged RetSt carrying expression e.
   print_Stat skips the expression when e is NULL. */
Stat mk_RetStat(int line, Exp e)
{
  Stat node = checked_malloc(sizeof *node);

  node->kind = RetSt;
  node->u.retst.e = e;
  node->line = line;
  return node;
}

/* Allocate a Stat node tagged SeqSt wrapping the statement list sl.
   This constructor takes no line argument, so the line field is set to 0
   instead of being left with whatever the allocator returned; print_Stat
   does not print the line of a SeqSt. */
Stat mk_SeqStat(StatList sl)
{
  Stat s = checked_malloc(sizeof(*s));
  s->line = 0;
  s->kind = SeqSt;
  s->u.seqst.sl = sl;
  return s;
}

/* Allocate an ExpList cell holding e whose tail is next. */
ExpList mk_ExpList(Exp e, ExpList next)
{
  ExpList cell = checked_malloc(sizeof *cell);

  cell->next = next;
  cell->e = e;
  return cell;
}

/* Allocate an Exp node tagged BinOpExp applying operator b to the
   operands e1 and e2. */
Exp mk_BinOpExp(int line, binop b, Exp e1, Exp e2)
{
  Exp node = checked_malloc(sizeof *node);

  node->kind = BinOpExp;
  node->u.binopexp.e2 = e2;
  node->u.binopexp.e1 = e1;
  node->u.binopexp.b = b;
  node->line = line;
  return node;
}

/* Allocate an Exp node tagged UnOpExp applying operator u to operand e1. */
Exp mk_UnOpExp(int line, unop u, Exp e1)
{
  Exp node = checked_malloc(sizeof *node);

  node->kind = UnOpExp;
  node->u.unopexp.e = e1;
  node->u.unopexp.u = u;
  node->line = line;
  return node;
}


/* Allocate an Exp node tagged LvalExp wrapping lvalue l. */
Exp mk_LvalExp(int line, Lvalue l)
{
  Exp node = checked_malloc(sizeof *node);

  node->kind = LvalExp;
  node->u.lvalexp.l = l;
  node->line = line;
  return node;
}


/* Allocate an Exp node tagged CallExp with identifier i and the
   expression list el. */
Exp mk_CallExp(int line, id i, ExpList el)
{
  Exp node = checked_malloc(sizeof *node);

  node->kind = CallExp;
  node->u.callexp.el = el;
  node->u.callexp.i = i;
  node->line = line;
  return node;
}

/* Allocate an Exp node tagged ArrayExp with identifier i and the array
   initializer list al. */
Exp mk_ArrayExp(int line, id i, ArrayInitList al)
{
  Exp node = checked_malloc(sizeof *node);

  node->kind = ArrayExp;
  node->u.arrayexp.al = al;
  node->u.arrayexp.i = i;
  node->line = line;
  return node;
}

/* Allocate an Exp node tagged RecordExp with identifier i and the record
   initializer list rl. */
Exp mk_RecordExp(int line, id i, RecordInitList rl)
{
  Exp node = checked_malloc(sizeof *node);

  node->kind = RecordExp;
  node->u.recordexp.rl = rl;
  node->u.recordexp.i = i;
  node->line = line;
  return node;
}


/* Allocate an Exp node tagged IntConst holding the integer i. */
Exp mk_IntConstExp(int line, int i)
{
  Exp node = checked_malloc(sizeof *node);

  node->kind = IntConst;
  node->u.intconstexp.i = i;
  node->line = line;
  return node;
}

/* Allocate an Exp node tagged RealConst; the value is kept as the
   string r (the pointer is stored, not copied). */
Exp mk_RealConstExp(int line, string r)
{
  Exp node = checked_malloc(sizeof *node);

  node->kind = RealConst;
  node->u.realconstexp.r = r;
  node->line = line;
  return node;
}


/* Allocate an Exp node tagged StringConst holding the string c (the
   pointer is stored, not copied). */
Exp mk_StringConstExp(int line, string c)
{
  Exp node = checked_malloc(sizeof *node);

  node->kind = StringConst;
  node->u.stringconstexp.c = c;
  node->line = line;
  return node;
}

/* Allocate an ArrayInitList cell holding a whose tail is next. */
ArrayInitList mk_ArrayInitList(ArrayInit a, ArrayInitList next)
{
  ArrayInitList cell = checked_malloc(sizeof *cell);

  cell->next = next;
  cell->a = a;
  return cell;
}

/* Allocate an ArrayInit node holding the expression pair e1, e2. */
ArrayInit mk_ArrayInit(Exp e1, Exp e2)
{
  ArrayInit node = checked_malloc(sizeof *node);

  node->e2 = e2;
  node->e1 = e1;
  return node;
}

/* Allocate a RecordInitList cell holding r whose tail is next. */
RecordInitList mk_RecordInitList(RecordInit r, RecordInitList next)
{
  RecordInitList cell = checked_malloc(sizeof *cell);

  cell->next = next;
  cell->r = r;
  return cell;
}

/* Allocate a RecordInit node pairing identifier i with expression e. */
RecordInit mk_RecordInit(id i, Exp e)
{
  RecordInit node = checked_malloc(sizeof *node);

  node->e = e;
  node->i = i;
  return node;
}

/* Allocate an LvalueList cell holding l whose tail is next. */
LvalueList mk_LvalueList(Lvalue l, LvalueList next)
{
  LvalueList cell = checked_malloc(sizeof *cell);

  cell->next = next;
  cell->l = l;
  return cell;
}

/* Allocate an Lvalue node tagged Var that refers to identifier i. */
Lvalue mk_VarLvalue(int line, id i)
{
  Lvalue node = checked_malloc(sizeof *node);

  node->kind = Var;
  node->u.var.i = i;
  node->line = line;
  return node;
}

/* Allocate an Lvalue node tagged ArrayDeref combining the lvalue l1
   with the expression e. */
Lvalue mk_ArrayDerefLvalue(int line, Lvalue l1, Exp e)
{
  Lvalue node = checked_malloc(sizeof *node);

  node->kind = ArrayDeref;
  node->u.arrayderef.e = e;
  node->u.arrayderef.l = l1;
  node->line = line;
  return node;
}

/* Allocate an Lvalue node tagged RecordDeref combining the lvalue l1
   with the identifier i. */
Lvalue mk_RecordDerefLvalue(int line, Lvalue l1, id i)
{
  Lvalue node = checked_malloc(sizeof *node);

  node->kind = RecordDeref;
  node->u.recordderef.i = i;
  node->u.recordderef.l = l1;
  node->line = line;
  return node;
}


/* Printing functions for ASTs. */
		
/* Start a new output line on out and indent it by t levels of two
   spaces each.  A zero or negative t emits only the newline. */
void tabs(FILE *out, int t)
{
  int remaining;

  fputc('\n', out);
  for (remaining = t; remaining > 0; remaining--) {
    fputs("  ", out);
  }
}

/* Forward declarations for the printers defined below.  They call one
   another recursively (for example print_Body -> print_DecList ->
   print_ProcDecList -> print_Body), so each is declared before its first
   use.  Every printer writes to `out`; `tab` is the indentation depth a
   printer uses when it starts a new line through tabs(). */
void print_Program(FILE *out,Program p);
void print_Body(FILE *out, int tab, Body b);
void print_DecList(FILE *out, int tab, DecList dl);
void print_VarDecList(FILE *out, int tab, VarDecList vdl);
void print_TypeDecList(FILE *out, int tab, TypeDecList tdl);
void print_ProcDecList(FILE *out, int tab, ProcDecList pdl);
void print_StatList(FILE *out, int tab, StatList sl);
void print_Stat(FILE *out, int tab, Stat s);
void print_Type(FILE *out, int tab,Type t);
void print_ComponentList(FILE *out, int tab,ComponentList cl);
void print_FormalParamList(FILE *out, int tab,FormalParamList fpl);
void print_ExpList(FILE *out, int tab, ExpList el);
void print_Exp(FILE *out, int tab, Exp e);
void print_ArrayInitList(FILE *out, int tab, ArrayInitList ail);
void print_RecordInitList(FILE *out, int tab, RecordInitList ril);
void print_LvalueList(FILE *out, int tab, LvalueList ll);
void print_Lvalue(FILE *out, int tab, Lvalue l);

/* Print the whole program p to out: its body at indentation depth 0,
   followed by a final newline.  p must not be NULL. */
void print_Program(FILE *out, Program p)
{
  assert(p);
  print_Body(out, 0, p->b);
  fputc('\n', out);
}

/* Print body b as "(BodyDef  <line> <decls><stat>)" on a fresh line
   indented one level deeper than tab.  b must not be NULL. */
void print_Body(FILE *out, int tab, Body b)
{
  assert(b);

  tabs(out, tab + 1);
  fprintf(out, "(BodyDef  %d ", b->line);
  print_DecList(out, tab + 1, b->dl);
  print_Stat(out, tab + 2, b->s);
  fputs(")", out);
}
		
/* Print the declaration list dl in parentheses.  Each entry goes on its
   own line and is dispatched on its kind to the matching list printer.
   Every cell must carry a non-NULL declaration. */
void print_DecList(FILE *out, int tab, DecList dl)
{
  DecList cur;

  fputs("(", out);
  for (cur = dl; cur != NULL; cur = cur->next) {
    Dec d = cur->d;

    tabs(out, tab + 1);
    assert(d);
    switch (d->kind) {
    case VarDecs:
      fputs("(VarDecs ", out);
      print_VarDecList(out, tab + 1, d->u.vardecs.vl);
      fputs(")", out);
      break;
    case TypeDecs:
      fputs("(TypeDecs ", out);
      print_TypeDecList(out, tab + 1, d->u.typedecs.tl);
      fputs(")", out);
      break;
    case ProcDecs:
      fputs("(ProcDecs ", out);
      print_ProcDecList(out, tab + 1, d->u.procdecs.pl);
      fputs(")", out);
      break;
    default:
      assert(0);
    }
  }
  fputs(")", out);
}

/* Print the variable declaration list vdl in parentheses, one
   "(VarDec <line> <id> <type><exp>)" entry per line.  Every cell must
   carry a non-NULL VarDec. */
void print_VarDecList(FILE *out, int tab, VarDecList vdl)
{
  VarDecList cur;

  fputs("(", out);
  for (cur = vdl; cur != NULL; cur = cur->next) {
    VarDec v = cur->v;

    assert(v);
    tabs(out, tab + 1);
    fprintf(out, "(VarDec %d %s ", v->line, v->i);
    print_Type(out, tab + 1, v->t);
    print_Exp(out, tab + 1, v->e);
    fputs(")", out);
  }
  fputs(")", out);
}

		
/* Print the type declaration list tdl in parentheses, one
   "(TypeDec <line> <id> <type>)" entry per line.  Every cell must carry
   a non-NULL TypeDec. */
void print_TypeDecList(FILE *out, int tab, TypeDecList tdl)
{
  TypeDecList cur;

  fputs("(", out);
  for (cur = tdl; cur != NULL; cur = cur->next) {
    TypeDec td = cur->t;

    assert(td);
    tabs(out, tab + 1);
    fprintf(out, "(TypeDec %d %s ", td->line, td->i);
    print_Type(out, tab + 1, td->t);
    fputs(")", out);
  }
  fputs(")", out);
}
		
/* Print the procedure declaration list pdl in parentheses, one
   "(ProcDec <line> <id> <params><type><body>)" entry per line.  Every
   cell must carry a non-NULL ProcDec. */
void print_ProcDecList(FILE *out, int tab, ProcDecList pdl)
{
  ProcDecList cur;

  fputs("(", out);
  for (cur = pdl; cur != NULL; cur = cur->next) {
    ProcDec pd = cur->p;

    assert(pd);
    tabs(out, tab + 1);
    fprintf(out, "(ProcDec %d %s ", pd->line, pd->i);
    print_FormalParamList(out, tab + 1, pd->fl);
    print_Type(out, tab + 1, pd->t);
    print_Body(out, tab + 1, pd->b);
    fputs(")", out);
  }
  fputs(")", out);
}
		
/* Print the statement list sl in parentheses; each statement is printed
   by print_Stat one indentation level deeper than tab. */
void print_StatList(FILE *out, int tab, StatList sl)
{
  StatList cur;

  fputs("(", out);
  for (cur = sl; cur != NULL; cur = cur->next) {
    print_Stat(out, tab + 1, cur->s);
  }
  fputs(")", out);
}

/* Print statement s as an s-expression, starting on a fresh line
   indented by tab.  The output form depends on s->kind; nested
   statements and expressions are printed with the same tab value.
   s must not be NULL. */
void print_Stat(FILE *out, int tab, Stat s)
{
  assert(s);
  tabs(out, tab);

  switch (s->kind) {
  case AssignSt:
    fprintf(out, "(AssignSt %d ", s->line);
    print_Lvalue(out, tab, s->u.assignst.l);
    print_Exp(out, tab, s->u.assignst.e);
    fputs(")", out);
    break;

  case CallSt:
    fprintf(out, "(CallSt %d %s ", s->line, s->u.callst.i);
    print_ExpList(out, tab, s->u.callst.el);
    fputs(")", out);
    break;

  case ReadSt:
    fprintf(out, "(ReadSt %d ", s->line);
    print_LvalueList(out, tab, s->u.readst.ll);
    fputs(")", out);
    break;

  case WriteSt:
    fprintf(out, "(WriteSt %d ", s->line);
    print_ExpList(out, tab, s->u.writest.el);
    fputs(")", out);
    break;

  case IfSt:
    fprintf(out, "(IfSt %d ", s->line);
    print_Exp(out, tab, s->u.ifst.e);
    print_Stat(out, tab, s->u.ifst.s1);
    print_Stat(out, tab, s->u.ifst.s2);
    fputs(")", out);
    break;

  case WhileSt:
    fprintf(out, "(WhileSt %d ", s->line);
    print_Exp(out, tab, s->u.whilest.e);
    print_Stat(out, tab, s->u.whilest.s);
    fputs(")", out);
    break;

  case LoopSt:
    fprintf(out, "(LoopSt %d ", s->line);
    print_Stat(out, tab, s->u.loopst.s);
    fputs(")", out);
    break;

  case ForSt:
    fprintf(out, "(ForSt %d %s ", s->line, s->u.forst.i);
    print_Exp(out, tab, s->u.forst.e1);
    print_Exp(out, tab, s->u.forst.e2);
    print_Exp(out, tab, s->u.forst.e3);
    print_Stat(out, tab, s->u.forst.s);
    fputs(")", out);
    break;

  case ExitSt:
    fprintf(out, "(ExitSt %d)", s->line);
    break;

  case RetSt:
    fprintf(out, "(RetSt %d ", s->line);
    /* A return statement may have no expression. */
    if (s->u.retst.e != NULL) {
      print_Exp(out, tab, s->u.retst.e);
    }
    fputs(")", out);
    break;

  case SeqSt:
    /* No line number is printed for a sequence; note the trailing space. */
    fputs("(SeqSt ", out);
    print_StatList(out, tab, s->u.seqst.sl);
    fputs(") ", out);
    break;

  default:
    assert(0);
  }
}
		
/* Print type t as an s-expression on the current line.  A NULL t is
   printed as "(NoTyp)". */
void print_Type(FILE *out, int tab, Type t)
{
  if (t == NULL) {
    fputs("(NoTyp)", out);
    return;
  }

  switch (t->kind) {
  case NamedTyp:
    fprintf(out, "(NamedTyp %d %s)", t->line, t->u.namedtyp.i);
    break;
  case ArrayTyp:
    fprintf(out, "(ArrayTyp %d ", t->line);
    print_Type(out, tab, t->u.arraytyp.t);
    fputs(")", out);
    break;
  case RecordTyp:
    fprintf(out, "(RecordTyp %d ", t->line);
    print_ComponentList(out, tab, t->u.recordtyp.cl);
    fputs(")", out);
    break;
  default:
    assert(0);
  }
}
		
/* Print the component list cl in parentheses, one
   "(Comp <line> <id> <type>)" entry per line.  Every cell must carry a
   non-NULL Component. */
void print_ComponentList(FILE *out, int tab, ComponentList cl)
{
  ComponentList cur;

  fputs("(", out);
  for (cur = cl; cur != NULL; cur = cur->next) {
    Component comp = cur->c;

    assert(comp);
    tabs(out, tab + 1);
    fprintf(out, "(Comp %d %s ", comp->line, comp->i);
    print_Type(out, tab, comp->t);
    fputs(")", out);
  }
  fputs(")", out);
}


/* Print the formal parameter list fpl in parentheses, one
   "(Param <line> <id> <type>)" entry per line.  Every cell must carry a
   non-NULL FormalParam. */
void print_FormalParamList(FILE *out, int tab, FormalParamList fpl)
{
  FormalParamList cur;

  fputs("(", out);
  for (cur = fpl; cur != NULL; cur = cur->next) {
    FormalParam param = cur->f;

    assert(param);
    tabs(out, tab + 1);
    fprintf(out, "(Param %d %s ", param->line, param->i);
    print_Type(out, tab, param->t);
    fputs(")", out);
  }
  fputs(")", out);
}
		
/* Print the expression list el in parentheses, passing tab unchanged to
   print_Exp for each element.  Every cell must carry a non-NULL Exp. */
void print_ExpList(FILE *out, int tab, ExpList el)
{
  ExpList cur;

  fputs("(", out);
  for (cur = el; cur != NULL; cur = cur->next) {
    assert(cur->e);
    print_Exp(out, tab, cur->e);
  }
  fputs(")", out);
}

/* Printable names of the binary operators.  print_Exp indexes this table
   with the binop stored in a BinOpExp node, so the order presumably
   mirrors the binop enumeration in ast.h -- TODO confirm against ast.h. */
char *binop_name[] = {"GT","LT","EQ","GE","LE","NE","PLUS","MINUS","TIMES","SLASH",
		      "DIV","MOD","AND","OR"};

/* Printable names of the unary operators, indexed by print_Exp with the
   unop stored in a UnOpExp node (order presumably mirrors the unop
   enumeration in ast.h -- TODO confirm). */
char *unop_name[] = {"UPLUS","UMINUS","NOT"};

/* Print expression e as an s-expression.  Compound expressions start a
   fresh line one level deeper than tab; the three constant kinds are
   printed inline on the current line.  Operator codes are checked
   against the name tables before they are used as indices.
   e must not be NULL. */
void print_Exp(FILE *out, int tab, Exp e)
{
  assert(e);

  switch (e->kind) {
  case BinOpExp: {
    const size_t n_binops = sizeof(binop_name) / sizeof(binop_name[0]);

    tabs(out, tab + 1);
    fprintf(out, "(BinOpExp %d ", e->line);
    assert(e->u.binopexp.b >= 0 && e->u.binopexp.b < n_binops);
    fprintf(out, " %s ", binop_name[e->u.binopexp.b]);
    print_Exp(out, tab + 1, e->u.binopexp.e1);
    print_Exp(out, tab + 1, e->u.binopexp.e2);
    fputs(")", out);
    break;
  }

  case UnOpExp: {
    const size_t n_unops = sizeof(unop_name) / sizeof(unop_name[0]);

    tabs(out, tab + 1);
    fprintf(out, "(UnOpExp %d ", e->line);
    assert(e->u.unopexp.u >= 0 && e->u.unopexp.u < n_unops);
    fprintf(out, " %s ", unop_name[e->u.unopexp.u]);
    print_Exp(out, tab + 1, e->u.unopexp.e);
    fputs(")", out);
    break;
  }

  case LvalExp:
    /* No line number is printed for this kind. */
    tabs(out, tab + 1);
    fputs("(LvalExp ", out);
    print_Lvalue(out, tab + 1, e->u.lvalexp.l);
    fputs(")", out);
    break;

  case CallExp:
    tabs(out, tab + 1);
    fprintf(out, "(CallExp %d %s ", e->line, e->u.callexp.i);
    print_ExpList(out, tab + 1, e->u.callexp.el);
    fputs(")", out);
    break;

  case RecordExp:
    tabs(out, tab + 1);
    fprintf(out, "(RecordExp %d %s ", e->line, e->u.recordexp.i);
    print_RecordInitList(out, tab + 1, e->u.recordexp.rl);
    fputs(")", out);
    break;

  case ArrayExp:
    tabs(out, tab + 1);
    fprintf(out, "(ArrayExp %d %s ", e->line, e->u.arrayexp.i);
    print_ArrayInitList(out, tab + 1, e->u.arrayexp.al);
    fputs(")", out);
    break;

  case IntConst:
    fprintf(out, "(IntConst %d %d)", e->line, e->u.intconstexp.i);
    break;

  case RealConst:
    fprintf(out, "(RealConst %d \"%s\")", e->line, e->u.realconstexp.r);
    break;

  case StringConst:
    fprintf(out, "(StringConst %d \"%s\")", e->line, e->u.stringconstexp.c);
    break;

  default:
    assert(0);
  }
}
		
/* Print the record initializer list ril in parentheses, one
   "(RecordInit <id> <exp>) " entry per line (each entry is followed by a
   space).  Every cell must carry a non-NULL RecordInit. */
void print_RecordInitList(FILE *out, int tab, RecordInitList ril)
{
  RecordInitList cur;

  fputs("(", out);
  for (cur = ril; cur != NULL; cur = cur->next) {
    RecordInit init = cur->r;

    assert(init);
    tabs(out, tab + 1);
    fprintf(out, "(RecordInit %s ", init->i);
    print_Exp(out, tab + 1, init->e);
    fputs(") ", out);
  }
  fputs(")", out);
}


/* Print the array initializer list ail in parentheses, one
   "(ArrayInit <exp><exp>) " entry per line (each entry is followed by a
   space).  Every cell must carry a non-NULL ArrayInit. */
void print_ArrayInitList(FILE *out, int tab, ArrayInitList ail)
{
  ArrayInitList cur;

  fputs("(", out);
  for (cur = ail; cur != NULL; cur = cur->next) {
    ArrayInit init = cur->a;

    assert(init);
    tabs(out, tab + 1);
    fputs("(ArrayInit ", out);
    print_Exp(out, tab + 1, init->e1);
    print_Exp(out, tab + 1, init->e2);
    fputs(") ", out);
  }
  fputs(")", out);
}
		
/* Print the lvalue list ll in parentheses, passing tab unchanged to
   print_Lvalue for each element. */
void print_LvalueList(FILE *out, int tab, LvalueList ll)
{
  LvalueList cur;

  fputs("(", out);
  for (cur = ll; cur != NULL; cur = cur->next) {
    print_Lvalue(out, tab, cur->l);
  }
  fputs(")", out);
}

/* Print lvalue l as an s-expression on the current line, recursing into
   the inner lvalue of a dereference.  The two dereference forms end with
   a trailing space.  l must not be NULL. */
void print_Lvalue(FILE *out, int tab, Lvalue l)
{
  assert(l);

  switch (l->kind) {
  case Var:
    fprintf(out, "(Var %d %s)", l->line, l->u.var.i);
    break;

  case ArrayDeref:
    fprintf(out, "(ArrayDeref %d ", l->line);
    print_Lvalue(out, tab, l->u.arrayderef.l);
    print_Exp(out, tab, l->u.arrayderef.e);
    fputs(") ", out);
    break;

  case RecordDeref:
    fprintf(out, "(RecordDeref %d ", l->line);
    print_Lvalue(out, tab, l->u.recordderef.l);
    fprintf(out, " %s ) ", l->u.recordderef.i);
    break;

  default:
    assert(0);
  }
}

/* Return a freshly allocated copy of the NUL-terminated string s.
   The storage comes from checked_malloc; s must not be NULL. */
char *strsave(char *s) {
  size_t bytes = strlen(s) + 1;   /* characters plus the terminating NUL */
  char *copy = checked_malloc(bytes);

  memcpy(copy, s, bytes);
  return copy;
}
