/* Very simple Sparc JIT for JVM subset. */
/* Warning: This code is not ANSI C; it uses many gcc-specific extensions. */

#include <stdio.h>
#include <assert.h>
#include "basics.h"
#include "class.h"
#include "bytecode.h"
#include "param_size.h"
#include "stack_size_change.h"
#include "sparc.h"

/* Support routines for generated sparc code. */

/* Exception kinds understood by exception(). Generated code loads one of
   these into %o0 before calling exception(); alloc_array passes
   ARRAY_SIZE_EXN directly. */
enum {
  NULL_PTR_EXN = 0,      /* null array reference   -> "NullPointer" */
  INDEX_BOUNDS_EXN = 1,  /* index outside array    -> "ArrayIndexOutOfBounds" */
  ARRAY_SIZE_EXN = 2     /* negative NEWARRAY size -> "NegativeArraySize" */
};
/* Runtime helper reached from generated code (and from alloc_array) when a
   Java exception would be raised. Prints "Uncaught Exception: <name>" and a
   newline to stdout, then terminates the process with exit status 1; it
   never returns.
   reason - NULL_PTR_EXN, INDEX_BOUNDS_EXN or ARRAY_SIZE_EXN. */
void exception(int reason) {
  printf("Uncaught Exception: ");
  switch (reason) {
  case NULL_PTR_EXN:
    printf("NullPointer");
    break;
  case INDEX_BOUNDS_EXN:
    printf("ArrayIndexOutOfBounds");
    break;
  case ARRAY_SIZE_EXN:
    printf("NegativeArraySize");
    break;
  default:
    /* Should not happen; report the bad code instead of an empty name. */
    printf("Unknown(%d)", reason);
    break;
  }
  printf("\n");
  exit(1);
}

/* Runtime helper for PrintStream.print(int): writes i in decimal to stdout,
   without a trailing newline. */
void print_int(int i) {
  fprintf(stdout, "%d", i);
}

/* Runtime helper for PrintStream.print(String): writes s to stdout without
   a trailing newline. A null reference (e.g. produced by ACONST_NULL) prints
   "null", as Java does; passing NULL to printf's %s would be undefined. */
void print_string(char *s) {
  fputs(s != NULL ? s : "null", stdout);
}

/* Array allocation */
u4 *alloc_array(int size) {
  if (size < 0)
    exception(ARRAY_SIZE_EXN);
  else {
    int i;
    u4 *a = alloc((size+1) * sizeof(u4));
    a[0] = size;
    for (i = 1; i <= size; i++)
      a[i] = 0;
    return &(a[1]);
  }
}
    
/* ---------------------------------------------------------- */

/* Globals */

/* The class we're interpreting/jitting. Set by main from the loaded class
   file; compute_sizes and compile read it to resolve constant-pool
   references (method/field names and classes, constants) and to find
   statically called methods. */
static Class *this_class;

/* ---------------------------------------------------------- */

/* Compute the operand-stack depth at each point in a method.

   Walks the control-flow graph from pc 0 with an empty stack. Straight-line
   code is followed instruction by instruction, GOTO/GOTO_W are followed to
   their target, and each conditional branch queues a trace for its target
   (with the current depth) while tracing continues at the fall-through
   instruction. A trace ends when it reaches an instruction whose depth is
   already known, or at a return instruction (control never continues past
   a return).

   method  - the method to analyze; INVOKESTATIC callees are resolved
             through this_class.
   size_at - output array, assumed to be of length method->code_length.
             On return, size_at[pc] is the stack depth before the
             instruction starting at pc. It is -1 if no analyzed path
             reaches pc, which also covers pcs in the middle of a
             multi-byte instruction.

   Dies if analysis leaves the bytecode array. That happens for a branch
   target out of range, or for a path that runs past the last instruction
   without returning. */
void compute_sizes(Method *method, int *size_at) {
  int i;

  // initially, mark size unknown at every point
  for (i = 0; i < method->code_length; i++)
    size_at[i] = -1;
  if (method->code_length == 0)
    return;  // nothing to analyze (and avoids a zero-length trace array)

  // initialize traces
  typedef struct {
    int start_pc;    // first instruction of the trace
    int start_size;  // stack depth on entry to start_pc
  } Trace;
  int trace_count = 0;
  // Only conditional branches push traces. Each branch is analyzed at most
  // once and is at least 3 bytes long, so code_length entries always suffice.
  Trace traces[method->code_length];

  // process all traces
  traces[trace_count++] = (Trace) {0,0};
  while (trace_count > 0) {
    Trace trace = traces[--trace_count];
    int pc = trace.start_pc;
    int size = trace.start_size;
    int falls_through = 1;  // cleared when the trace reaches a return
    while (falls_through) {
      // Valid bytecode never branches or falls outside its code array;
      // check rather than index size_at out of bounds.
      if (pc < 0 || pc >= method->code_length)
        die("stack size analysis left the code at pc %d", pc);
      if (size_at[pc] != -1)
        break;  // already analyzed from here on
      size_at[pc] = size;
      u1 opcode = method->code[pc];

      // compute change in stack size
      switch (opcode) {
      case INVOKESTATIC:
        {
          u2 index = (method->code[pc+1] << 8) | method->code[pc+2];
          char *method_class = get_method_class(this_class,index);
          /* only support static calls into this class */
          if ((strcmp(method_class,get_class(this_class,this_class->class_index)) == 0)) {
            Method *m = find_method(this_class,
                                    get_method_name(this_class,index),
                                    get_method_type(this_class,index));
            u2 nargs = count_args(get_utf8(this_class,m->descriptor_index));
            u2 nresult = count_result(get_utf8(this_class,m->descriptor_index));
            size = size - nargs + nresult;
          } else
            die ("unimplemented INVOKESTATIC to method in different class %s", method_class);
          break;
        }
      default:
        assert(stack_size_change[opcode] != UNDEFINED_SIZE_CHANGE);
        size += stack_size_change[opcode];
        break;
      }

      // compute next pc to analyze
      switch (opcode) {
      case GOTO:
        {
          int16 offset = (method->code[pc+1] << 8) | method->code[pc+2];
          pc += offset;  // follow the jump
          break;
        }
      case GOTO_W:
        {
          int32 offset = (method->code[pc+1] << 24) | (method->code[pc+2] << 16) |
            (method->code[pc+3] << 8) | (method->code[pc+4]);
          pc += offset; // follow the jump
          break;
        }
      case IRETURN:
      case ARETURN:
      case RETURN:
        // control never continues past a return: end this trace
        falls_through = 0;
        break;
      case IFNULL:
      case IFEQ:
      case IFGE:
      case IFGT:
      case IFLE:
      case IFLT:
      case IFNONNULL:
      case IFNE:
      case IF_ACMPEQ:
      case IF_ICMPEQ:
      case IF_ICMPGE:
      case IF_ICMPGT:
      case IF_ICMPLE:
      case IF_ICMPLT:
      case IF_ACMPNE:
      case IF_ICMPNE:
        {
          int16 offset = (method->code[pc+1] << 8) | method->code[pc+2];
          // add a new trace beginning at jump target
          int new_pc = pc+offset;
          traces[trace_count++] = (Trace) {new_pc,size};
          // meanwhile, continue tracing where we are
          /* fallthrough */
        }
      default:
        assert(param_size[opcode] != UNDEFINED_PARAM_SIZE);
        pc += 1 + param_size[opcode];
      }
    }
  }
}

/* ---------------------------------------------------------- */


/* Compile a method to SPARC code.

   Translates the bytecode of `method` into SPARC instructions in a freshly
   allocated buffer, stores the buffer in method->scode, and returns it.
   method->scode is set before any code is generated, so recursive
   INVOKESTATIC calls find an address instead of compiling forever. Callees
   in this_class that have no code yet are compiled on demand. When done,
   prints the method name to stderr and passes the code to dump().

   Register model:
   - operand-stack slot k lives in stackreg[k];
   - local variable k lives in varreg[k];
   - outgoing call arguments go in argreg;
   - G1 is used as scratch.
   Stack depths come from compute_sizes.

   Dies on unsupported instructions, on methods needing more registers than
   are available, and if the code buffer overflows. */
u4 *compile(Method *method) {

  /* Very simple register allocation scheme:
     just map each stack slot and each local variable to a unique register. */
  static reg stackreg[] = {L0,L1,L2,L3,L4,L5,L6,L7};
  static reg varreg[] = {I0,I1,I2,I3,I4,I5};
  static reg argreg[] = {O0,O1,O2,O3,O4,O5};  // for outgoing arguments

  // check some limits
  if (method->max_locals > (sizeof(varreg)/sizeof(reg)))
    die("insufficient variable registers");
  if (method->max_stack > (sizeof(stackreg)/sizeof(reg)))
    die("insufficient stack registers");

  // initialize stack size information
  int size_at[method->code_length];
  compute_sizes(method,size_at);

  // initialize sparc code buffer.
  // The longest translation below is IALOAD/IASTORE: null check + bounds
  // check + access = 14 instructions for a 1-byte bytecode. Every other
  // instruction needs fewer per bytecode byte, so 14 per byte plus the SAVE
  // header is an upper bound. Update it if a longer translation is added.
  int max_insts = method->code_length*14 + 1;
  u4 *scode = alloc(sizeof(u4) * max_insts);  // the sparc code buffer
  method->scode = scode;                  // store this now to avoid infinite recursion while compiling!
  int spc = 0;                            // next free offset in the sparc code buffer

  /* Append one instruction word. The word is computed into a temporary
     before spc is incremented. GEN_RELCALL reads spc, and reading and
     incrementing spc in one assignment would be unsequenced (undefined). */
#define EMIT(inst) ({ u4 emit_inst_ = (inst); \
                      if (spc >= max_insts) die("insufficient code space"); \
                      else scode[spc++] = emit_inst_; })
  /* PC-relative CALL to addr from the slot about to be filled (word displacement). */
#define GEN_RELCALL(addr) (gen_call((((u4)(addr)) - ((u4)(scode+spc)))>>2))

  // initialize backpatching structures
  int scode_at[method->code_length];      // sparc code offsets corresponding to each bytecode addr
  typedef struct {
    int loc;          // offset in sparc code requiring backpatching
    int target;       // corresponding bytecode addr of branch target
  } Backpatch;
  // One entry per branch bytecode. Each branch is at least 3 bytes long,
  // so code_length entries always suffice.
  Backpatch backpatches[method->code_length];
  int backpatch_count = 0;               // number of backpatches

  /* Null check: if register r is zero, call exception(NULL_PTR_EXN).
       0: tst r
       1: bne 4           ! non-null: skip the handler
       2: nop             ! delay slot
       3: call exception
       4: mov NULL_PTR_EXN, %o0   ! delay slot: pass exception kind
     (falls through to whatever follows, at offset 5) */
#define NULL_CHECK(r) ({ \
	EMIT(GEN_TST(r)); \
	EMIT(gen_bicc(NE_COND,0,4)); \
	EMIT(GEN_NOP); \
	EMIT(GEN_RELCALL(&exception)); \
	EMIT(GEN_MOV_IMM(O0,NULL_PTR_EXN)); \
      })

  /* Bounds check of index register idx against array register arr.
     If idx is not in [0, length), call exception(INDEX_BOUNDS_EXN). The
     comparison is unsigned, so negative indices fail too. The length is the
     word just before the first element (see alloc_array). Clobbers G1.
     On every path, idx is scaled in place to a byte offset (idx << 2).
       0: ld   [arr-4], %g1
       1: cmp  %g1, idx          ! flags from length - idx
       2: bleu 5                 ! length <=u idx: out of bounds
       3: sll  idx, 2, idx       ! delay slot, always executed
       4: ba   8                 ! in bounds: skip the handler
       5: nop                    ! delay slot of ba; target of bleu
       6: call exception
       7: mov  INDEX_BOUNDS_EXN, %o0   ! delay slot: pass exception kind
       8: (the array access that follows) */
#define BOUNDS_CHECK(arr,idx) ({ \
	EMIT(gen_mop_imm(LD_OP,G1,(arr),-4)); \
	EMIT(GEN_CMP(G1,(idx))); \
	EMIT(gen_bicc(LEU_COND,0,3)); \
	EMIT(gen_op_imm(SLL_OP,(idx),(idx),2)); \
	EMIT(gen_bicc(A_COND,0,4)); \
	EMIT(GEN_NOP); \
	EMIT(GEN_RELCALL(&exception)); \
	EMIT(GEN_MOV_IMM(O0,INDEX_BOUNDS_EXN)); \
      })

  /* Prepare %y for SDIV. SDIV divides the 64-bit value %y:rs1, so %y must
     hold the sign extension of the 32-bit dividend: 0 or -1. Writing 0
     gives wrong quotients for negative dividends. Clobbers G1. WRY is a
     delayed write, so wait three instructions before %y is read.
       0: tst  dividend
       1: bge  4               ! non-negative: keep %g1 = 0
       2: mov  0, %g1          ! delay slot, always executed
       3: mov  -1, %g1         ! only reached for a negative dividend
       4: wr   %g1, %g0, %y
       5-7: nop */
#define SET_Y_SIGN(dividend) ({ \
	EMIT(GEN_TST(dividend)); \
	EMIT(gen_bicc(GE_COND,0,3)); \
	EMIT(GEN_MOV_IMM(G1,0)); \
	EMIT(GEN_MOV_IMM(G1,-1)); \
	EMIT(gen_op(WRY_OP,0,G1,G0)); \
	EMIT(GEN_NOP); \
	EMIT(GEN_NOP); \
	EMIT(GEN_NOP); \
      })

  // emit header
  EMIT(gen_op_imm(SAVE_OP,O6,O6,-96));

  // emit code for each instruction
  int pc;
  for (pc = 0; pc < method->code_length; ) {
    u1 opcode = method->code[pc];
    int size = size_at[pc];
    scode_at[pc] = spc;
    if (size < 0) {
      // compute_sizes found no path to this instruction, so it is dead code.
      // Nothing branches here, and there is no stack depth to compile
      // against (stackreg[-1] would be out of bounds). Emit nothing.
      assert (param_size[opcode] != UNDEFINED_PARAM_SIZE);
      pc += 1 + param_size[opcode];
      continue;
    }
    switch (opcode) {
    case ARRAYLENGTH:
      {
	NULL_CHECK(stackreg[size-1]);
	EMIT(gen_mop_imm(LD_OP,stackreg[size-1],stackreg[size-1],-4));  // fetch length
	break;
      }
    case BIPUSH:
      {
	int c = (int8)(method->code[pc+1]);
	EMIT(GEN_MOV_IMM(stackreg[size],c));
	break;
      }
    case DUP:
      {
	EMIT(GEN_MOV(stackreg[size],stackreg[size-1]));
	break;
      }
    case DUP2:
      {
	EMIT(GEN_MOV(stackreg[size+1],stackreg[size-1]));
	EMIT(GEN_MOV(stackreg[size],stackreg[size-2]));
	break;
      }
    case GETSTATIC:
      {
	u2 index = (method->code[pc+1] << 8) | method->code[pc+2];
	/* kludge for java.lang.System.out */
	if ((strcmp (get_field_class(this_class,index),"java/lang/System") == 0) &&
	    (strcmp (get_field_name(this_class,index),"out") == 0))
	  ; /* just "push" an empty slot */
	else
	  die("unimplemented instruction GETSTATIC");
	break;
      }
    case GOTO:
      {
	int16 offset = (method->code[pc+1] << 8) | method->code[pc+2];
	int new_pc = pc+offset;
	backpatches[backpatch_count++] = (Backpatch) {spc,new_pc};
	EMIT(gen_bicc(A_COND,0,0));
	EMIT(GEN_NOP);
	break;
      }
    case GOTO_W:
      {
	int32 offset = (method->code[pc+1] << 24) | (method->code[pc+2] << 16) |
	  (method->code[pc+3] << 8) | (method->code[pc+4]);
	int new_pc = pc+offset;
	backpatches[backpatch_count++] = (Backpatch) {spc,new_pc};
	EMIT(gen_bicc(A_COND,0,0));
	EMIT(GEN_NOP);
	break;
      }

#define IBINOP(op) \
      { \
        EMIT(gen_op(op,stackreg[size-2],stackreg[size-2],stackreg[size-1]));\
	break;\
      }

    case IADD:
      IBINOP(ADD_OP);
    case IALOAD:
      {
	NULL_CHECK(stackreg[size-2]);
	BOUNDS_CHECK(stackreg[size-2],stackreg[size-1]);  // also scales index to bytes
	EMIT(gen_mop(LD_OP,stackreg[size-2],stackreg[size-2],stackreg[size-1]));  // fetch element
	break;
      }
    case IAND:
      IBINOP(AND_OP);
    case IASTORE:
      {
	NULL_CHECK(stackreg[size-3]);
	BOUNDS_CHECK(stackreg[size-3],stackreg[size-2]);  // also scales index to bytes
	EMIT(gen_mop(ST_OP,stackreg[size-1],stackreg[size-3],stackreg[size-2])); // store element
	break;
      }
#define ICONST(v)\
      { EMIT(GEN_MOV_IMM(stackreg[size],v));\
         break;\
      }

    case ICONST_M1:
      ICONST(-1);
    case ACONST_NULL:
    case ICONST_0:
      ICONST(0);
    case ICONST_1:
      ICONST(1);
    case ICONST_2:
      ICONST(2);
    case ICONST_3:
      ICONST(3);
    case ICONST_4:
      ICONST(4);
    case ICONST_5:
      ICONST(5);
    case IDIV:
      SET_Y_SIGN(stackreg[size-2]);  // %y:dividend must be the sign-extended dividend
      IBINOP(SDIV_OP);

#define IF_ICMP(cond) \
      { \
        EMIT(GEN_CMP(stackreg[size-2],stackreg[size-1])); \
	int16 offset = (method->code[pc+1] << 8) | method->code[pc+2]; \
	int new_pc = pc+offset; \
	backpatches[backpatch_count++] = (Backpatch) {spc,new_pc};\
        EMIT(gen_bicc(cond,0,0)); \
        EMIT(GEN_NOP); \
        break; \
}

    case IF_ACMPEQ:
    case IF_ICMPEQ:
      IF_ICMP(E_COND);
    case IF_ICMPGE:
      IF_ICMP(GE_COND);
    case IF_ICMPGT:
      IF_ICMP(G_COND);
    case IF_ICMPLE:
      IF_ICMP(LE_COND);
    case IF_ICMPLT:
      IF_ICMP(L_COND);
    case IF_ACMPNE:
    case IF_ICMPNE:
      IF_ICMP(NE_COND);

#define IF0(cond) \
      {\
	EMIT(GEN_TST(stackreg[size-1])); \
	int16 offset = (method->code[pc+1] << 8) | method->code[pc+2];\
	int new_pc = pc+offset; \
	backpatches[backpatch_count++] = (Backpatch) {spc,new_pc};\
        EMIT(gen_bicc(cond,0,0)); \
        EMIT(GEN_NOP); \
        break; \
}

    case IFNULL:
    case IFEQ:
      IF0(E_COND);
    case IFGE:
      IF0(GE_COND);
    case IFGT:
      IF0(G_COND);
    case IFLE:
      IF0(LE_COND);
    case IFLT:
      IF0(L_COND);
    case IFNONNULL:
    case IFNE:
      IF0(NE_COND);

    case IINC:
      {
	u1 index = method->code[pc+1];
	int8 incr = (int8) (method->code[pc+2]);
	EMIT(gen_op_imm(ADD_OP,varreg[index],varreg[index],incr));
	break;
      }
    case ALOAD:
    case ILOAD:
      {
	u1 index = method->code[pc+1];
	EMIT(GEN_MOV(stackreg[size],varreg[index]));
	break;
      }

#define ILOADL(index)\
      {\
        EMIT(GEN_MOV(stackreg[size],varreg[index]));\
	break;\
      }

    case ALOAD_0:
    case ILOAD_0:
      ILOADL(0);

    case ALOAD_1:
    case ILOAD_1:
      ILOADL(1);

    case ALOAD_2:
    case ILOAD_2:
      ILOADL(2);

    case ALOAD_3:
    case ILOAD_3:
      ILOADL(3);

    case IMUL:
      IBINOP(SMUL_OP);

    case INEG:
      {
	EMIT(gen_op(SUB_OP,stackreg[size-1],G0,stackreg[size-1]));
	break;
      }
    case INVOKESTATIC:
      {
	u2 index = (method->code[pc+1] << 8) | method->code[pc+2];
	char *method_class = get_method_class(this_class,index);
	/* only support static calls into this class */
	if ((strcmp(method_class,get_class(this_class,this_class->class_index)) == 0)) {
	  Method *m = find_method(this_class,
				  get_method_name(this_class,index),
				  get_method_type(this_class,index));

	  /* compile method if necessary */
	  if (!m->scode)
	    compile(m);
	  u2 nargs = count_args(get_utf8(this_class,m->descriptor_index));
	  /* move stack regs corresponding to arguments into arg regs */
	  if (nargs > (sizeof(argreg)/sizeof(reg)))
	    die("insufficient argument registers");
	  int i;
	  for (i = 0; i < nargs; i++)
	    EMIT(GEN_MOV(argreg[i],stackreg[size-nargs+i]));
	  EMIT(GEN_RELCALL(m->scode));
	  EMIT(GEN_NOP);
	  if (count_result(get_utf8(this_class,m->descriptor_index)) == 1) {
	    /* move result to new top-of-stack reg */
	    EMIT(GEN_MOV(stackreg[size-nargs],O0));
	  }
	} else
	  die ("unimplemented INVOKESTATIC to method in different class %s", method_class);
	break;
      }
    case INVOKEVIRTUAL:
      {
	u2 index = (method->code[pc+1] << 8) | method->code[pc+2];
	/* kludge for java.io.PrintStream.print */
	if ((strcmp(get_method_class(this_class,index),"java/io/PrintStream") == 0)
	    && (strcmp(get_method_name(this_class,index),"print") == 0)) {
	  char *type = get_method_type(this_class,index);
	  if (strcmp(type,"(Ljava/lang/String;)V") == 0) {
	    /* arg is string */
	    EMIT(GEN_RELCALL(&print_string));
	  } else if (strcmp(type,"(I)V") == 0) {
	    /* arg is int */
	    EMIT(GEN_RELCALL(&print_int));
	  } else
	    die ("unimplemented PrintStream.print method with signature %s", type);
	  EMIT(GEN_MOV(O0,stackreg[size-1])); // in delay slot: pass value
	} else
	  die ("unimplemented instruction INVOKEVIRTUAL");
	break;
      }
    case IOR:
      IBINOP(OR_OP);
    case IREM:
      {
	// no simple instruction to do this: compute a - (a / b) * b
	SET_Y_SIGN(stackreg[size-2]);
	EMIT(gen_op(SDIV_OP,G1,stackreg[size-2],stackreg[size-1]));  // G1 = a / b
	EMIT(gen_op(SMUL_OP,G1,G1,stackreg[size-1]));                // G1 = (a / b) * b
	EMIT(gen_op(SUB_OP,stackreg[size-2],stackreg[size-2],G1));   // a - (a / b) * b
	break;
      }
    case ARETURN:
    case IRETURN:
      {
	EMIT(GEN_MOV(I0,stackreg[size-1]));  // move return value to C return reg
	EMIT(gen_op_imm(JMPL_OP,G0,I7,8));    // return
	EMIT(gen_op(RESTORE_OP,G0,G0,G0));           // delay slot: restore register window
	break;
      }
    case ASTORE:
    case ISTORE:
      {
	u1 index = method->code[pc+1];
	EMIT(GEN_MOV(varreg[index],stackreg[size-1]));
	break;
      }
#define ISTOREL(index)\
      {\
        EMIT(GEN_MOV(varreg[index],stackreg[size-1])); \
        break;\
      }

    case ASTORE_0:
    case ISTORE_0:
      ISTOREL(0);
    case ASTORE_1:
    case ISTORE_1:
      ISTOREL(1);
    case ASTORE_2:
    case ISTORE_2:
      ISTOREL(2);
    case ASTORE_3:
    case ISTORE_3:
      ISTOREL(3);
    case ISUB:
      IBINOP(SUB_OP);
    case IXOR:
      IBINOP(XOR_OP);
    case LDC:
      {
	u1 index = method->code[pc+1];
	u1 tag = get_constant_tag(this_class,index);
	u4 c;
	if (tag == CONSTANT_String)
	  c = (u4) (get_string(this_class,index));
	else if (tag == CONSTANT_Integer)
	  c = (u4) (get_integer(this_class,index));
	else
	  die ("unsupported constant type %d", tag);
	EMIT(gen_sethi(stackreg[size],HIBITS(c)));                         // get value
	EMIT(gen_op_imm(OR_OP,stackreg[size],stackreg[size],LOBITS(c)));  // (assume it needs sethi!)
	break;
      }
    case LDC_W:
      {
	u2 index = (method->code[pc+1] << 8) | method->code[pc+2];
	u1 tag = get_constant_tag(this_class,index);
	u4 c;
	if (tag == CONSTANT_String)
	  c = (u4) (get_string(this_class,index));
	else if (tag == CONSTANT_Integer)
	  c = (u4) (get_integer(this_class,index));
	else
	  die ("unsupported constant type %d", tag);
	EMIT(gen_sethi(stackreg[size],HIBITS(c)));
	EMIT(gen_op_imm(OR_OP,stackreg[size],stackreg[size],LOBITS(c)));
	break;
      }
    case NEWARRAY:
      {
	u1 atype = method->code[pc+1];
	if (atype != 10)
	  die("unimplemented array type %d",atype);
	EMIT(GEN_RELCALL(&alloc_array));     // all the interesting work is done in alloc_array
	EMIT(GEN_MOV(O0,stackreg[size-1])); // in delay slot: pass length
	EMIT(GEN_MOV(stackreg[size-1],O0)); // retrieve result
	break;
      }
    case NOP:
    case POP:
    case POP2:
      {
	break;
      }
    case RETURN:
      {
	EMIT(gen_op_imm(JMPL_OP,G0,I7,8));  // return
	EMIT(gen_op(RESTORE_OP,G0,G0,G0));  // delay slot: restore register window
	break;
      }
    case SIPUSH:
      {
	int c = (int16) ((method->code[pc+1] << 8) | method->code[pc+2]);
	if (c >= -4096 && c <= 4095)
	  EMIT(GEN_MOV_IMM(stackreg[size],c));
	else {
	  EMIT(gen_sethi(stackreg[size],HIBITS(c)));
	  EMIT(gen_op_imm(OR_OP,stackreg[size],stackreg[size],LOBITS(c)));
	}
	break;
      }
    case SWAP:
      {
	EMIT(GEN_MOV(G1,stackreg[size-1]));
	EMIT(GEN_MOV(stackreg[size-1],stackreg[size-2]));
	EMIT(GEN_MOV(stackreg[size-2],G1));
	break;
      }
    default:
      die("unimplemented instruction code %i", opcode);
    }

    assert (param_size[opcode] != UNDEFINED_PARAM_SIZE);
    pc += 1 + param_size[opcode];

  }

  // perform backpatching: fill each branch's 22-bit word displacement
  while (backpatch_count-- > 0)
    scode[backpatches[backpatch_count].loc]
      |= ((scode_at[backpatches[backpatch_count].target] - backpatches[backpatch_count].loc) & 0x3fffff);

  // dump code we just generated
  fprintf(stderr,"%s:\n", get_utf8(this_class,method->name_index));
  dump(scode,spc);

  return scode;
}

/* ---------------------------------------------------------- */

/* Entry point: jit ClassName
   Loads ClassName.class and compiles its main(String[]) method. Statically
   called methods of the class are compiled on demand by compile(). The
   generated code is dumped but not executed.
   Returns 0 on success, 1 on usage, file or lookup errors. */
int main(int argc, char **argv) {
  if (argc < 2) {
    fprintf(stderr, "usage: %s ClassName\n", argc > 0 ? argv[0] : "jit");
    return 1;
  }

  // Build "<ClassName>.class", refusing names that would not fit.
  char file[255];
  int len = snprintf(file, sizeof file, "%s.class", argv[1]);
  if (len < 0 || (size_t) len >= sizeof file) {
    fprintf(stderr, "class name too long: %s\n", argv[1]);
    return 1;
  }

  // Class files are binary: "rb" avoids newline translation on some platforms.
  FILE *cf = fopen(file, "rb");
  if (cf == NULL) {
    perror(file);
    return 1;
  }
  this_class = read_class(cf);
  fclose(cf);

  Method *main_method = find_method(this_class,"main","([Ljava/lang/String;)V");
  if (main_method == NULL) {
    fprintf(stderr, "%s: no main([Ljava/lang/String;)V method\n", file);
    return 1;
  }
  u4 *scode = compile(main_method);

  // Calling the generated code is disabled, as in the original source.
  // Presumably this is because it can only execute on a SPARC host (TODO
  // confirm). To re-enable:
  //   typedef void (*func)(int x);
  //   ((func)scode)(0);   // call into main method
  (void) scode;
  return 0;
}
