/* Support for loading a class file into an internal data structure
   and for querying that structure.
   Only a subset of the class file format is supported.
*/

#include <stdio.h>
#include <assert.h>
#include "basics.h"
#include "class.h"

/* Constant pool utilities. */
/* Fetch constant pool entry `index`, using the class file's 1-based
   numbering. Valid indices are 1 .. constants_count-1; entry i is
   stored at constants[i-1]. */
static Constant get_constant(Class *cl, u2 index) {
  const Constant *pool = cl->constants;
  assert (index >= 1 && index < cl->constants_count);
  return pool[index - 1];
}

/* Return the tag byte of constant pool entry `index`. */
u1 get_constant_tag(Class *cl, u2 index) {
  Constant entry = get_constant(cl,index);
  return entry.tag;
}

/* Return the bytes of the CONSTANT_Utf8 entry at `index` as a char
   pointer (read_constant stores them with a trailing '\0').
   Asserts that the entry really is a Utf8 constant. */
char *get_utf8(Class *class,u2 index) {
  Constant entry = get_constant(class,index);
  assert (entry.tag == CONSTANT_Utf8);
  return (char *) entry.u.utf8.bytes;
}

/* Return the name of the CONSTANT_Class entry at `index`, resolved
   through its Utf8 name entry. */
char *get_class(Class *class,u2 index) {
  Constant entry = get_constant(class,index);
  u2 name_index;
  assert (entry.tag == CONSTANT_Class);
  name_index = entry.u.class.name_index;
  return get_utf8(class,name_index);
}

/* Return the text of the CONSTANT_String entry at `index`, resolved
   through its Utf8 entry. */
char *get_string(Class *class,u2 index) {
  Constant entry = get_constant(class,index);
  u2 utf8_index;
  assert (entry.tag == CONSTANT_String);
  utf8_index = entry.u.string.string_index;
  return get_utf8(class,utf8_index);
}

/* Return the raw 4 bytes (big-endian-decoded by readu4) of the
   CONSTANT_Integer entry at `index`. */
u4 get_integer(Class *class,u2 index) {
  Constant entry = get_constant(class,index);
  u4 value;
  assert (entry.tag == CONSTANT_Integer);
  value = entry.u.integer.bytes;
  return value;
}

/* Return the name part of the CONSTANT_NameAndType entry at `index`. */
char *get_name(Class *class,u2 index) {
  Constant entry = get_constant(class,index);
  u2 name_index;
  assert (entry.tag == CONSTANT_NameAndType);
  name_index = entry.u.name_and_type.name_index;
  return get_utf8(class,name_index);
}

/* Return the descriptor (type) part of the CONSTANT_NameAndType entry
   at `index`. */
char *get_type(Class *class,u2 index) {
  Constant entry = get_constant(class,index);
  u2 descriptor_index;
  assert (entry.tag == CONSTANT_NameAndType);
  descriptor_index = entry.u.name_and_type.descriptor_index;
  return get_utf8(class,descriptor_index);
}

/* Return the owning class name of the CONSTANT_Fieldref at `index`. */
char *get_field_class(Class *class,u2 index) {
  Constant entry = get_constant(class,index);
  u2 owner_index;
  assert (entry.tag == CONSTANT_Fieldref);
  owner_index = entry.u.fieldref.class_index;
  return get_class(class,owner_index);
}

/* Return the field name of the CONSTANT_Fieldref at `index`. */
char *get_field_name(Class *class,u2 index) {
  Constant entry = get_constant(class,index);
  u2 nat_index;
  assert (entry.tag == CONSTANT_Fieldref);
  nat_index = entry.u.fieldref.name_and_type_index;
  return get_name(class,nat_index);
}

/* Return the field descriptor of the CONSTANT_Fieldref at `index`. */
char *get_field_type(Class *class,u2 index) {
  Constant entry = get_constant(class,index);
  u2 nat_index;
  assert (entry.tag == CONSTANT_Fieldref);
  nat_index = entry.u.fieldref.name_and_type_index;
  return get_type(class,nat_index);
}

/* Return the owning class name of the CONSTANT_Methodref at `index`. */
char *get_method_class(Class *class,u2 index) {
  Constant entry = get_constant(class,index);
  u2 owner_index;
  assert (entry.tag == CONSTANT_Methodref);
  owner_index = entry.u.methodref.class_index;
  return get_class(class,owner_index);
}

/* Return the method name of the CONSTANT_Methodref at `index`. */
char *get_method_name(Class *class,u2 index) {
  Constant entry = get_constant(class,index);
  u2 nat_index;
  assert (entry.tag == CONSTANT_Methodref);
  nat_index = entry.u.methodref.name_and_type_index;
  return get_name(class,nat_index);
}

/* Return the method descriptor of the CONSTANT_Methodref at `index`. */
char *get_method_type(Class *class,u2 index) {
  Constant entry = get_constant(class,index);
  u2 nat_index;
  assert (entry.tag == CONSTANT_Methodref);
  nat_index = entry.u.methodref.name_and_type_index;
  return get_type(class,nat_index);
}

/* Method table utilities.*/

/* Find a method (by name and type) in the class method table.
   Names and descriptors are compared as strings rather than by
   constant pool index, because nothing guarantees that a given name
   or type appears only once in the pool.
   Returns a pointer into class->methods for the first match.
   If no method matches, the assert fires; when assertions are compiled
   out (NDEBUG) NULL is returned instead of falling off the end of a
   non-void function (which would be undefined behavior). */
Method *find_method(Class *class,char *method_name, char *method_type) {
  int i;
  Method *mp = class->methods;
  for (i = 0; i < class->methods_count; i++, mp++) {
    if (strcmp(get_utf8(class,mp->name_index), method_name) == 0 &&
        strcmp(get_utf8(class,mp->descriptor_index), method_type) == 0)
      return mp;
  }
  assert(false);  /* method was not found */
  return NULL;
}
  
/* Descriptor utilities. */
/* Consume one field descriptor from a sequence of them and return a
   pointer just past it. Any number of leading '[' (array) prefixes are
   skipped; an 'L' object type runs through its terminating ';'; any
   other character is assumed to be a valid one-character BaseType. */
static char *consume_descriptor(char* s) {
  char *p = s;
  while (*p == '[')
    p++;
  if (*p != 'L')
    return p + 1;
  do
    p++;
  while (*p != ';');
  return p + 1;
}

/* Count the parameters in a method descriptor such as "(I[JLFoo;)V".
   Every parameter is counted once, i.e. this assumes each argument
   occupies a single word. */
int count_args(char *s) {
  char *p = s;
  int n = 0;
  assert (*p == '(');
  for (++p; *p != ')'; p = consume_descriptor(p))
    n++;
  return n;
}

/* Return how many results a method descriptor produces: 0 when the
   return descriptor after ')' is 'V' (void), otherwise 1 (any result is
   assumed to occupy a single word). */
int count_result(char *s) {
  char *p;
  assert (*s == '(');
  /* Walk the parameter list so that p stops on the closing ')'. */
  for (p = s + 1; *p != ')'; p = consume_descriptor(p))
    ;
  return p[1] == 'V' ? 0 : 1;
}

/* Read class file into internal data structure. */

/* Read one byte from cf into *u.
   A truncated or unreadable class file is reported via die() instead of
   silently storing EOF (which would truncate to 0xFF). */
static void readu1(FILE *cf,u1 *u) {
  int c = getc(cf);
  if (c == EOF)
    die("unexpected end of class file");
  *u = (u1) c;
}

/* Read a big-endian 2-byte value from cf into *u.
   Hitting EOF is reported via die(). Previously EOF silently became
   0xFFFF, which later turned into enormous element counts. */
static void readu2(FILE *cf,u2 *u) {
  int hi = getc(cf);
  int lo = getc(cf);
  if (hi == EOF || lo == EOF)
    die("unexpected end of class file");
  /* Shift as unsigned so a negative value can never be left-shifted. */
  *u = (u2) (((unsigned) hi << 8) | (unsigned) lo);
}

/* Read a big-endian 4-byte value from cf into *u.
   Hitting EOF on any of the four bytes is reported via die() rather than
   producing a garbage value (e.g. an attribute length of 0xFFFFFFFF). */
static void readu4(FILE *cf,u4 *u) {
  u4 value = 0;
  int k;
  for (k = 0; k < 4; k++) {
    int b = getc(cf);
    if (b == EOF)
      die("unexpected end of class file");
    value = (value << 8) | (u4) (unsigned) b;
  }
  *u = value;
}

/* Read one constant pool entry (a tag byte followed by its payload) into *c.
   Only Class, String, Fieldref, Methodref, NameAndType, Integer and Utf8
   entries are understood; any other tag terminates via die().
   Utf8 payloads are copied into a freshly alloc'd buffer with a trailing
   '\0' so they can be used directly as C strings. */
static void read_constant(FILE *cf,Constant *c) {
  readu1(cf,&c->tag);
  switch (c->tag) {
  case CONSTANT_Class:
    readu2(cf,&c->u.class.name_index);
    break;
  case CONSTANT_String:
    readu2(cf,&c->u.string.string_index);
    break;
  case CONSTANT_Fieldref:
    readu2(cf,&c->u.fieldref.class_index);
    readu2(cf,&c->u.fieldref.name_and_type_index);
    break;
  case CONSTANT_Methodref:
    readu2(cf,&c->u.methodref.class_index);
    readu2(cf,&c->u.methodref.name_and_type_index);
    break;
  case CONSTANT_NameAndType:
    readu2(cf,&c->u.name_and_type.name_index);
    readu2(cf,&c->u.name_and_type.descriptor_index);
    break;
  case CONSTANT_Integer:
    readu4(cf,&c->u.integer.bytes);
    break;
  case CONSTANT_Utf8: {
    u1 *text;
    int len;
    int k;
    readu2(cf,&c->u.utf8.length);
    len = c->u.utf8.length;
    text = (u1 *) (alloc(sizeof(u1) * (len + 1)));
    c->u.utf8.bytes = text;
    for (k = 0; k < len; k++)
      readu1(cf,&text[k]);
    text[len] = '\0';
    break;
  }
  default:
    die("unsupported constant pool entry tag %d",c->tag);
  }
}

/* Skip the body of one attribute: a u4 length followed by that many
   bytes. Expects the attribute name index to have already been read.
   The loop counter is a u4 like the length itself, avoiding a
   signed/unsigned comparison and signed overflow for huge lengths. */
static void skip_attribute(FILE *cf) {
  u4 length;
  u4 i;
  u1 dummy1;
  readu4(cf,&length);
  for (i = 0; i < length; i++)
    readu1(cf,&dummy1);
}

/* Skip a whole attribute table: a u2 count followed by that many
   attributes, each being a u2 name index plus a length-prefixed body. */
static void skip_attributes(FILE *cf) {
  u2 attributes_count;
  u2 ignored_name;
  int remaining;
  readu2(cf,&attributes_count); /* attributes count */
  for (remaining = attributes_count; remaining > 0; remaining--) {
    readu2(cf,&ignored_name); /* skip attribute name index */
    skip_attribute(cf);
  }
}

/* Read one method_info structure from cf into *m.
   Kept: name_index, descriptor_index, and from the "Code" attribute
   max_stack, max_locals, code_length and the bytecode itself.
   Skipped: access flags, the exception table, Code's nested attributes,
   and every non-Code method attribute.
   We need Class cl only for the constant pool (to recognize "Code"). */
static void read_method(FILE *cf, Class *cl,Method *m) {
  int i;
  u4 dummy4;
  u2 dummy2;
  u2 attributes_count;
  readu2(cf,&dummy2); /* skip access flags */
  readu2(cf,&m->name_index);
  readu2(cf,&m->descriptor_index);
  /* A method may carry no Code attribute; give the code fields defined
     values so callers never read uninitialized memory. */
  m->max_stack = 0;
  m->max_locals = 0;
  m->code_length = 0;
  m->code = NULL;
  readu2(cf,&attributes_count);
  for (i = 0; i < attributes_count; i++) {
    u2 name_index;
    readu2(cf,&name_index);
    if (strcmp(get_utf8(cl,name_index),"Code") == 0) {
      u4 j;
      /* Separate variable: the original reused the outer attribute
         count here, which cut the attribute loop short (or overran it)
         and desynchronized the rest of the file. */
      u2 exception_table_length;
      readu4(cf,&dummy4); /* skip attribute length */
      readu2(cf,&m->max_stack);
      readu2(cf,&m->max_locals);
      readu4(cf,&m->code_length);
      m->code = (u1 *) (alloc(sizeof(u1) * m->code_length));
      for (j = 0; j < m->code_length; j++)
        readu1(cf,m->code+j);
      readu2(cf,&exception_table_length);
      for (j = 0; j < exception_table_length; j++) {
        readu2(cf,&dummy2); /* skip start_pc */
        readu2(cf,&dummy2); /* skip end_pc */
        readu2(cf,&dummy2); /* skip handler_pc */
        readu2(cf,&dummy2); /* skip catch_type */
      }
      skip_attributes(cf); /* attributes nested inside Code */
    } else
      skip_attribute(cf);
  }
  m->scode = 0;
}

/* Read a class file from cf into a freshly allocated Class.
   Parsed: the constant pool (supported tags only, see read_constant),
   the this_class index (cl->class_index), and the methods table.
   Skipped: version numbers, access flags, super class, interfaces and
   fields. Class-level attributes at the end of the file are not read.
   Dies via die() on a bad magic number or a zero constant pool count. */
Class *read_class(FILE *cf) {
  int i;
  u4 magic;
  u2 dummy2;
  u2 count;
  Class *cl = alloc(sizeof(*cl));
  readu4(cf,&magic);
  if (magic != 0xCAFEBABEu)
    die("not a class file (bad magic number 0x%08lx)", (unsigned long) magic);
  readu2(cf,&dummy2); /* skip minor version */
  readu2(cf,&dummy2); /* skip major version */
  readu2(cf,&cl->constants_count);
  /* The pool holds constants_count - 1 entries (indices 1..count-1), so
     0 is malformed and would make the allocation size negative. */
  if (cl->constants_count == 0)
    die("invalid constant pool count 0");
  cl->constants = alloc(sizeof(Constant) * (cl->constants_count - 1));
  for (i = 0; i < cl->constants_count - 1; i++)
    read_constant(cf,cl->constants+i);
  readu2(cf,&dummy2); /* skip access flags */
  readu2(cf,&cl->class_index);
  readu2(cf,&dummy2); /* skip super class */
  readu2(cf,&count); /* interfaces count */
  for (i = 0; i < count; i++)
    readu2(cf,&dummy2); /* skip interface index */
  readu2(cf,&count); /* fields count */
  for (i = 0; i < count; i++) {
    readu2(cf,&dummy2); /* skip access flags */
    readu2(cf,&dummy2); /* skip name index */
    readu2(cf,&dummy2); /* skip descriptor_index */
    skip_attributes(cf);
  }
  readu2(cf,&cl->methods_count);
  cl->methods = alloc(sizeof(Method) * cl->methods_count);
  for (i = 0; i < cl->methods_count; i++)
    read_method(cf,cl,cl->methods+i);
  /* Class-level attributes are intentionally left unread. */
  return cl;
}

