import java.io.*;
import java.util.*;

// Liveness analysis

class Liveness {

  // Raised when the control-flow structure of a body is malformed; currently
  // only when a code line could fall through past the last line of the body.
  static class LivenessException extends Exception {
    LivenessException(String text) {
      super(text);
    }
  }

  // Utility class for describing sets of register temporary operands.
  // Only operands that are IR.RegTemp instances are ever stored. Equality and
  // hashing delegate to the underlying HashSet (element-wise).
  static class RegTempSet {
    HashSet set;  // contains RegTemp

    RegTempSet() {
      set = new HashSet();
    }

    // Add op if it is a RegTemp; operands of any other kind are ignored.
    void add(IR.Operand op) {
      if (op instanceof IR.RegTemp)
        set.add(op);
    }

    // Two RegTempSets are equal when they hold the same elements.
    // This is linear in the set size, so avoid calling it in hot inner loops.
    public boolean equals(Object os) {
      return (os instanceof RegTempSet)
          && ((RegTempSet) os).set.equals(set);
    }

    // Required so that equal sets have equal hash codes (Object contract).
    public int hashCode() {
      return set.hashCode();
    }

    // In-place difference: this := this - os.
    void diff(RegTempSet os) {
      set.removeAll(os.set);
    }

    // In-place union: this := this U os.
    void union(RegTempSet os) {
      set.addAll(os.set);
    }

    // Returns an independent copy: the new set can be modified without
    // affecting this one (the RegTemp elements themselves are shared).
    RegTempSet copy() {
      RegTempSet s = new RegTempSet();
      s.set = (HashSet) set.clone();
      return s;
    }

    // Formats as "{ a b c }" (element order follows HashSet iteration order).
    public String toString() {
      StringBuffer r = new StringBuffer("{ ");
      for (Iterator it = set.iterator(); it.hasNext();)
        r.append(it.next()).append(' ');
      r.append('}');
      return r.toString();
    }

    Iterator iterator() {
      return set.iterator();
    }
  }

  // Utility class for describing lists of integers (used for successor
  // code-line indices).
  static class IndexList {
    List list;  // contains Integer

    IndexList() {
      list = new ArrayList();
    }

    void add(int i) {
      list.add(Integer.valueOf(i));
    }

    int get(int p) {
      return ((Integer) list.get(p)).intValue();
    }

    int size() {
      return list.size();
    }

    // Formats as "[a,b,c]"; an empty list prints as "[]".
    public String toString() {
      StringBuffer r = new StringBuffer("[");
      for (int i = 0; i < size(); i++) {
        if (i > 0)
          r.append(',');
        r.append(get(i));
      }
      r.append(']');
      return r.toString();
    }
  }

  // Calculate successor information for each code line in a body.
  // Returns an array parallel to body.codeLines; entry i lists the code-line
  // indices that control may reach directly after line i:
  //   - RET: no successors
  //   - BA (unconditional branch): the branch target only
  //   - conditional branches: the branch target, then the next line
  //   - any other instruction, or a non-instruction line: the next line
  // Throws LivenessException if a line that can fall through is the last line.
  static IndexList[] calculateSuccessors (IR.Body body) throws LivenessException {
    int length = body.codeLines.length;
    IndexList[] allSuccs = new IndexList[length];
    for (int i = 0; i < length; i++) {
      IR.CodeLine cl = body.codeLines[i];
      IndexList succs = new IndexList();
      if (cl instanceof IR.Inst) {  // instanceof is false for null lines
        IR.Inst inst = (IR.Inst) cl;
        switch (inst.op) {
        case IR.RET_OP:  // return: no successors
          break;
        case IR.BA_OP:   // unconditional branch: target only
          succs.add(branchTarget(body, inst));
          break;
        case IR.BG_OP:   // conditional branches: target and fall-through
        case IR.BL_OP:
        case IR.BE_OP:
        case IR.BGE_OP:
        case IR.BLE_OP:
        case IR.BNE_OP:
        case IR.BGU_OP:
        case IR.BLU_OP:
        case IR.BGEU_OP:
        case IR.BLEU_OP:
          succs.add(branchTarget(body, inst));
          addFallThrough(succs, i, length);
          break;
        default:
          addFallThrough(succs, i, length);
          break;
        }
      } else {
        addFallThrough(succs, i, length);
      }
      allSuccs[i] = succs;
    }
    return allSuccs;
  }

  // Code-line index targeted by a branch whose operand 0 is an IR.Label.
  // body.labels maps a label number to the index of the line it marks
  // (it is used directly as a successor index).
  private static int branchTarget(IR.Body body, IR.Inst inst) {
    IR.Label target = (IR.Label) inst.rands[0];
    return body.labels[target.i];
  }

  // Append the fall-through successor i+1, rejecting a fall-through past the
  // last line of the body.
  private static void addFallThrough(IndexList succs, int i, int length)
      throws LivenessException {
    if (i + 1 >= length)
      throw new LivenessException("Body falls off end");
    succs.add(i + 1);
  }

  // Record in used/defined the RegTemp operands that code line cl reads and
  // writes. Non-instruction lines and opcodes not listed below (for example
  // branches and returns) contribute nothing. Non-RegTemp operands are
  // filtered out by RegTempSet.add.
  private static void collectUseDef(IR.CodeLine cl, RegTempSet used, RegTempSet defined) {
    if (!(cl instanceof IR.Inst))
      return;
    IR.Inst inst = (IR.Inst) cl;
    switch (inst.op) {
    case IR.LD_OP:   // use op0; define op1
    case IR.MOV_OP:
      used.add(inst.rands[0]);
      defined.add(inst.rands[1]);
      break;
    case IR.ST_OP:   // use op0, op1
    case IR.CMP_OP:
      used.add(inst.rands[0]);
      used.add(inst.rands[1]);
      break;
    case IR.ADD_OP:  // use op0, op1; define op2
    case IR.SUB_OP:
    case IR.SMUL_OP:
    case IR.SDIV_OP:
    case IR.ADDA_OP:
      used.add(inst.rands[0]);
      used.add(inst.rands[1]);
      defined.add(inst.rands[2]);
      break;
    case IR.SCALL_OP: // use op1
      used.add(inst.rands[1]);
      break;
    default:
      break;
    }
  }

  // Calculate liveOut information for each code line in a body.
  // Returns an array parallel to body.codeLines; entry i is the set of
  // register temporaries live on exit from line i. It solves the usual
  // backward dataflow equations by iterating to a fixed point:
  //   liveOut[i] = U { liveIn[s] : s a successor of i }
  //   liveIn[i]  = used[i] U (liveOut[i] - defined[i])
  // Throws LivenessException if the body's control flow is malformed
  // (see calculateSuccessors).
  static RegTempSet[] calculateLiveness (IR.Body body) throws LivenessException {
    IndexList[] allSuccs = calculateSuccessors(body);
    int count = body.codeLines.length;

    // Sets of temporaries used and defined by each code line.
    RegTempSet[] used = new RegTempSet[count];
    RegTempSet[] defined = new RegTempSet[count];
    for (int i = 0; i < count; i++) {
      used[i] = new RegTempSet();
      defined[i] = new RegTempSet();
      collectUseDef(body.codeLines[i], used[i], defined[i]);
    }

    RegTempSet[] liveIn = new RegTempSet[count];
    RegTempSet[] liveOut = new RegTempSet[count];
    for (int i = 0; i < count; i++) {
      liveIn[i] = new RegTempSet();
      liveOut[i] = new RegTempSet();
    }

    boolean changed = true;
    while (changed) {
      changed = false;
      // Visiting lines in reverse usually processes successors before their
      // predecessors, which speeds convergence for straight-line code.
      for (int i = count - 1; i >= 0; i--) {
        RegTempSet newLiveOut = new RegTempSet();
        IndexList succs = allSuccs[i];
        for (int k = 0; k < succs.size(); k++)
          newLiveOut.union(liveIn[succs.get(k)]);

        // Derive liveIn from the *new* liveOut so the two stay consistent.
        // Deriving it from the previous liveOut lets liveIn lag by one pass.
        // Predecessors reached through back edges then read a stale liveIn,
        // and the loop can stop before reaching the true fixed point.
        RegTempSet newLiveIn = newLiveOut.copy();
        newLiveIn.diff(defined[i]);
        newLiveIn.union(used[i]);

        if (!liveOut[i].equals(newLiveOut)) {
          liveOut[i] = newLiveOut;
          changed = true;
        }
        if (!liveIn[i].equals(newLiveIn)) {
          liveIn[i] = newLiveIn;
          changed = true;
        }
      }
    }
    return liveOut;
  }

  // Live interval of a temporary: the first (start) and last (end) code-line
  // index, inclusive, at which it is in the liveOut set. The temporary is not
  // necessarily live at every line in between.
  static class Interval {
    int start;
    int end;
    Interval (int start, int end) {
      this.start = start; this.end = end;
    }
  }

  // Calculate the live interval of each temporary in body.
  // Returns a Map from IR.RegTemp to Interval. A temporary that is never in
  // any liveOut set (e.g. defined but never subsequently used) has no entry.
  // NOTE(review): keys are looked up by IR.RegTemp's equals/hashCode; confirm
  // that RegTemp defines these consistently.
  // Throws LivenessException if the body's control flow is malformed.
  static Map calculateLiveIntervals(IR.Body body) throws LivenessException {
    Map liveIntervals = new HashMap();  // keys are RegTemp; values are Interval
    RegTempSet liveOut[] = calculateLiveness(body);
    for (int i = 0; i < body.codeLines.length; i++) {
      for (Iterator it = liveOut[i].iterator(); it.hasNext();) {
        IR.RegTemp t = (IR.RegTemp) (it.next());
        Interval n = (Interval) (liveIntervals.get(t));
        if (n == null) {
          // The first sighting fixes start, because lines are scanned in order.
          n = new Interval(i, i);
          liveIntervals.put(t, n);
        } else {
          // Each later sighting extends the end.
          n.end = i;
        }
      }
    }
    return liveIntervals;
  }

}


