#!/usr/bin/python

# Interim code generation script.

import sys, os, mllib
from cStringIO import StringIO

# Command line: <output directory> <java package> <spec XML file>.
# Fail with a usage message instead of a bare IndexError when arguments are
# missing (extra trailing arguments are ignored, as before).
if len(sys.argv) < 4:
  sys.stderr.write("usage: %s OUT_DIR OUT_PKG SPEC_FILE\n" % sys.argv[0])
  sys.exit(2)

out_dir = sys.argv[1]
out_pkg = sys.argv[2]
spec_file = sys.argv[3]

spec = mllib.xml_parse(spec_file)

def jbool(b):
  """Render the truth value of *b* as a Java boolean literal string."""
  literals = ("false", "true")
  return literals[bool(b)]

class Output:
  """Accumulates the lines of one generated Java source file.

  The constructor emits the package declaration and a fixed import header.
  Callers append further lines with line(), getter() and setter(), then call
  write() to save everything as <dir>/<package path>/<name>.java.
  """

  def __init__(self, dir, package, name):
    # dir: root output directory; package: dotted Java package name;
    # name: Java type name, also used as the file's base name.
    self.dir = dir
    self.package = package
    self.name = name
    self.lines = []

    self.line("package %s;" % self.package)
    self.line()
    self.line("import java.util.ArrayList;")
    self.line("import java.util.List;")
    self.line("import java.util.Map;")
    self.line("import java.util.UUID;")
    self.line()
    self.line("import org.apache.qpidity.transport.codec.Decoder;")
    self.line("import org.apache.qpidity.transport.codec.Encodable;")
    self.line("import org.apache.qpidity.transport.codec.Encoder;")
    self.line()
    self.line("import org.apache.qpidity.transport.network.Frame;")
    self.line()
    self.line()

  def line(self, l = ""):
    """Append one line of output (without its line terminator)."""
    self.lines.append(l)

  def getter(self, type, method, value, pre = None):
    """Emit `public final <type> <method>()` returning the expression *value*.

    *pre*, if given, is emitted as a statement before the return.
    """
    self.line()
    self.line("    public final %s %s() {" % (type, method))
    if pre:
      self.line("        %s;" % pre)
    self.line("        return %s;" % value)
    self.line("    }")

  def setter(self, type, method, variable, value = None, pre = None,
             post = None):
    """Emit a fluent `public final <name> <method>(...)` that assigns a field.

    When *value* is None, the method takes one `<type> value` parameter and
    assigns it to `this.<variable>`. Otherwise it takes no parameters and
    assigns the Java expression *value*. *pre* and *post* are optional
    statements emitted before and after the assignment. The generated method
    returns `this`.
    """
    if value is None:
      params = "%s value" % type
      value = "value"
    else:
      params = ""

    self.line()
    self.line("    public final %s %s(%s) {" % (self.name, method, params))
    if pre:
      self.line("        %s;" % pre)
    self.line("        this.%s = %s;" % (variable, value))
    if post:
      self.line("        %s;" % post)
    self.line("        return this;")
    self.line("    }")

  def write(self):
    """Write the accumulated lines to <dir>/<package path>/<name>.java.

    Missing package directories are created. Each line ends with a single
    newline character. The file is opened in text mode, so the platform's
    own line-ending translation applies; writing os.linesep here would
    produce doubled carriage returns on Windows.
    """
    pkg_dir = os.path.join(self.dir, *self.package.split("."))
    if not os.path.exists(pkg_dir):
      os.makedirs(pkg_dir)
    path = os.path.join(pkg_dir, "%s.java" % self.name)
    out = open(path, "w")
    try:
      for l in self.lines:
        out.write(l)
        out.write("\n")
    finally:
      # Close the handle even if a write fails.
      out.close()

# Spec type name -> Java type used for generated fields and accessors.
TYPES = {
  "longstr": "String",
  "shortstr": "String",
  "longlong": "long",
  "long": "long",
  "short": "int",
  "octet": "short",
  "bit": "boolean",
  "table": "Map<String,Object>",
  "timestamp": "long",
  "content": "String",
  "uuid": "UUID",
  "rfc1982-long-set": "RangeSet",
  "long-struct": "Struct",
  "signed-byte": "byte",
  "unsigned-byte": "short",
  "char": "char",
  "boolean": "boolean",
  "two-octets": "short",
  "signed-short": "short",
  "unsigned-short": "int",
  "four-octets": "int",
  "signed-int": "int",
  "unsigned-int": "long",
  "float": "float",
  "utf32-char": "char",
  "eight-octets": "long",
  "signed-long": "long",
  "unsigned-long": "long",
  "double": "double",
  "datetime": "long",
  "sixteen-octets": "byte[]",
  "thirty-two-octets": "byte[]",
  "sixty-four-octets": "byte[]",
  "_128-octets": "byte[]",
  "short-binary": "byte[]",
  "short-string": "String",
  "short-utf8-string": "String",
  "short-utf16-string": "String",
  "short-utf32-string": "String",
  "binary": "byte[]",
  "string": "String",
  "utf8-string": "String",
  "utf16-string": "String",
  "utf32-string": "String",
  "long-binary": "byte[]",
  "long-string": "String",
  "long-utf8-string": "String",
  "long-utf16-string": "String",
  "long-utf32-string": "String",
  "sequence": "List<Object>",
  "array": "List<Object>",
  "five-octets": "byte[]",
  "decimal": "byte[]",
  "nine-octets": "byte[]",
  "long-decimal": "byte[]",
  "void": "Void"
  }

# Java literal used to reset a field of each Java primitive type.
_PRIMITIVE_DEFAULTS = {
  "boolean": "false",
  "byte": "0",
  "short": "0",
  "char": "0",
  "int": "0",
  "long": "0",
  "float": "0",
  "double": "0"
  }

# Spec type -> Java literal assigned by the generated clearX() methods, which
# fall back to "null" for types missing here. "null" cannot be assigned to a
# Java primitive, so every type that TYPES maps onto a primitive needs an
# entry. Derived from TYPES; this includes all of the original hand-written
# entries (longlong, long, short, octet, timestamp -> "0"; bit -> "false").
DEFAULTS = dict([(t, _PRIMITIVE_DEFAULTS[jt]) for t, jt in TYPES.items()
                 if jt in _PRIMITIVE_DEFAULTS])

# Spec class name -> Java expression returned by getEncodedTrack() for its
# methods (see TRACKS.get(..., "Frame.L4") where structs are built).
TRACKS = {
  "connection": "Frame.L1",
  "session": "Frame.L2",
  "execution": "Frame.L3",
  None: None
  }

def camel(offset, *args):
  """Join hyphenated words into a camel-case identifier.

  Every argument is split on "-". The first *offset* resulting parts are
  kept as they are; every later part gets its first character upper-cased.
  For example, camel(0, "queue", "declare") -> "QueueDeclare" and
  camel(1, "set", "routing-key") -> "setRoutingKey".

  Empty parts (from a leading, trailing or doubled "-") contribute nothing
  instead of raising IndexError.
  """
  parts = []
  for a in args:
    parts.extend(a.split("-"))
  head = parts[:offset]
  # p[:1] is "" for an empty part, where p[0] would raise IndexError.
  tail = [p[:1].upper() + p[1:] for p in parts[offset:]]
  return "".join(head + tail)

def dromedary(s):
  """Lower-case the first character of *s*, e.g. "QueueDeclare" -> "queueDeclare".

  Raises IndexError for an empty string.
  """
  head, tail = s[0], s[1:]
  return head.lower() + tail

def scream(*args):
  """Build an UPPER_SNAKE_CASE name: hyphens become underscores, words joined by "_"."""
  words = [word.upper().replace("-", "_") for word in args]
  return "_".join(words)


# --- Type.java: enum of the field-table type codes declared in the spec. ---
types = Output(out_dir, out_pkg, "Type")
types.line("public enum Type")
types.line("{")
# codes maps each type code (as written in the spec) to its enum name.
# code_order remembers first-seen spec order, so the switch below is emitted
# deterministically rather than in dict iteration order.
codes = {}
code_order = []
for c in spec.query["amqp/constant"]:
  if c["@class"] == "field-table-type":
    name = c["@name"]
    if name.startswith("field-table-"):
      name = name[12:]
    # Java identifiers may not start with a digit.
    if name[0].isdigit():
      name = "_" + name
    val = c["@value"]
    if val not in codes:
      code_order.append(val)
    codes[val] = name
    # Determine width/fixed afresh for every constant. @lfwidth takes
    # precedence over @width (as before). A constant with neither used to
    # silently inherit the previous constant's values, or hit a NameError.
    width = c["@lfwidth"]
    if width is not None:
      fixed = "false"
    else:
      width = c["@width"]
      fixed = "true"
      if width is None:
        raise Exception("field-table-type %s declares neither width nor lfwidth"
                        % c["@name"])
    types.line("    %s((byte) %s, %s, %s)," %
               (scream(name), val, width, fixed))
types.line("    ;")

types.line("    public byte code;")
types.line("    public int width;")
types.line("    public boolean fixed;")

types.line("    Type(byte code, int width, boolean fixed)")
types.line("    {")
for arg in ("code", "width", "fixed"):
  types.line("        this.%s = %s;" % (arg, arg))
types.line("    }")

types.line("    public static Type get(byte code)")
types.line("    {")
types.line("        switch (code)")
types.line("        {")
for code in code_order:
  types.line("        case (byte) %s: return %s;" % (code, scream(codes[code])))
types.line("        default: return null;")
types.line("        }")
types.line("    }")

types.line("}")
types.write()


# --- Constant.java: every spec constant that declares no @datatype. ---
# (Indented with spaces like the rest of this file; the original used tabs.)
const = Output(out_dir, out_pkg, "Constant")
const.line("public interface Constant")
const.line("{")
for d in spec.query["amqp/constant"]:
  name = d["@name"]
  val = d["@value"]
  datatype = d["@datatype"]
  # Constants that carry a datatype are not emitted here.
  if datatype is None:
    const.line("    public static final int %s = %s;" % (scream(name), val))
const.line("}")
const.write()


# DOMAINS: domain name -> underlying type name (a struct domain maps to itself).
# STRUCTS: struct domain name -> generated Java class name.
DOMAINS = {}
STRUCTS = {}

for d in spec.query["amqp/domain"]:
  name = d["@name"]
  # Named base_type (not `type`) so the builtin `type` is not shadowed at
  # module level.
  base_type = d["@type"]
  if base_type is not None:
    DOMAINS[name] = base_type
  elif d["struct"] is not None:
    DOMAINS[name] = name
    STRUCTS[name] = camel(0, name)

def resolve(type):
  """Follow domain aliases in DOMAINS until a base type is reached.

  Returns *type* unchanged when it is not an aliased domain: either it is
  absent from DOMAINS, or it maps to itself (struct domains do).

  Raises Exception on a cyclic alias chain, instead of recursing until the
  interpreter's recursion limit is hit.
  """
  seen = set()
  while type in DOMAINS and DOMAINS[type] != type:
    if type in seen:
      raise Exception("cyclic domain definition involving %s" % type)
    seen.add(type)
    type = DOMAINS[type]
  return type

def jtype(type):
  """Map a resolved spec type to its Java type name.

  Struct domains map to their generated class name. Every other type is
  looked up in TYPES, which raises KeyError for an unknown type.
  """
  if type in STRUCTS:
    return STRUCTS[type]
  return TYPES[type]

def jclass(jt):
  """Strip generic arguments from a Java type name ("List<Object>" -> "List").

  A name without "<", or whose first character is "<", is returned unchanged.
  """
  raw = jt.split("<", 1)[0]
  return raw or jt

# Wrapper class for each Java primitive. Used where a reference type is
# required, e.g. generic type arguments and `.class` literals.
REFS = dict(
  boolean="Boolean",
  byte="Byte",
  short="Short",
  int="Integer",
  long="Long",
  float="Float",
  double="Double",
  char="Character"
)

def jref(jt):
  """Return the wrapper class for primitive *jt*, or *jt* itself otherwise."""
  if jt in REFS:
    return REFS[jt]
  return jt


# Camel-case field name -> Option enum constant name, for every "bit" field
# found while structs are collected. Struct consults it to pass such fields
# as `Option ...` varargs instead of positional parameters.
# NOTE(review): keyed by field name across all structs, so a non-bit field
# sharing a name with some bit field is also treated as an option -- confirm.
OPTIONS = {}

class Struct:
  """Model of one generated Java class (method, struct or result).

  Collects the fields of a spec node and renders the Java source via impl().
  parameters() and arguments() render the constructor signature and argument
  list, which the generated Invoker also uses.
  """

  def __init__(self, node, name, base, type, size, pack, track, content):
    # node: spec node the class is built from (queried for "result").
    # name: Java class name. base: Java superclass name.
    # type: integer TYPE code.
    # size/pack: values returned by getSizeWidth()/getPackWidth().
    # track: Java expression returned by getEncodedTrack() (methods only).
    # content: truth value returned by hasPayload() (methods only).
    self.node = node
    self.name = name
    self.base = base
    self.type = type
    self.size = size
    self.pack = pack
    self.track = track
    self.content = content
    # (resolved spec type, camel-case name) pairs, in declaration order.
    self.fields = []
    # Set when the spec declares a "ticket" field; that field is not stored.
    self.ticket = False

  def result(self):
    """Return the Java class name of this method's result, or None.

    Uses the node's "result" child: its @domain (camel-cased) when set,
    otherwise <name>Result.
    """
    r = self.node["result"]
    if not r: return
    name = r["@domain"]
    if not name:
      name = self.name + "Result"
    else:
      name = camel(0, name)
    return name

  def field(self, type, name):
    """Record a field; a field named "ticket" only sets self.ticket."""
    if name == "ticket":
      self.ticket = True
    else:
      self.fields.append((type, name))

  def impl(self, out):
    """Append the Java class for this struct to the Output *out*.

    Emits, in order:
    - TYPE and the fixed getters;
    - the FIELDS registry;
    - a has_<f> flag plus a backing member per field;
    - the constructors and dispatch();
    - per-field has/clear/get/set accessors;
    - per field, a static initializer that registers an anonymous Field
      subclass (has/get/read/write).
    """
    out.line("public class %s extends %s {" % (self.name, self.base))

    out.line()
    out.line("    public static final int TYPE = %d;" % self.type)
    out.getter("int", "getStructType", "TYPE")
    out.getter("int", "getSizeWidth", self.size)
    out.getter("int", "getPackWidth", self.pack)
    out.getter("boolean", "hasTicket", jbool(self.ticket))

    # Only method classes get hasPayload()/getEncodedTrack().
    if self.base == "Method":
      out.getter("boolean", "hasPayload", jbool(self.content))
      out.getter("byte", "getEncodedTrack", self.track)

    out.line()
    out.line("    private static final List<Field<?,?>> FIELDS = new ArrayList<Field<?,?>>();")
    out.line("    public List<Field<?,?>> getFields() { return FIELDS; }")
    out.line()

    out.line()
    for type, name in self.fields:
      out.line("    private boolean has_%s;" % name)
      out.line("    private %s %s;" % (jtype(type), name))

    # An explicit no-arg constructor is only needed when the full constructor
    # below takes arguments.
    if self.fields:
      out.line()
      out.line("    public %s() {}" % self.name)

    # Full constructor: non-option fields are positional; option (bit) fields
    # are set from the `Option ... _options` varargs (see parameters()).
    out.line()
    out.line("    public %s(%s) {" % (self.name, self.parameters()))
    opts = False
    for type, name in self.fields:
      if not OPTIONS.has_key(name):
        out.line("        %s(%s);" % (camel(1, "set", name), name))
      else:
        opts = True
    if opts:
      for type, name in self.fields:
        if OPTIONS.has_key(name):
          out.line("        boolean _%s = false;" % name)
      out.line("        for (int i=0; i < _options.length; i++) {")
      out.line("            switch (_options[i]) {")
      for type, name in self.fields:
        if OPTIONS.has_key(name):
          out.line("            case %s: _%s=true; break;" % (OPTIONS[name], name))
      # NO_OPTION is accepted and ignored; any other option is rejected.
      out.line("            case NO_OPTION: break;")
      out.line('            default: throw new IllegalArgumentException'
               '("invalid option: " + _options[i]);')
      out.line("            }")
      out.line("        }")
      for type, name in self.fields:
        if OPTIONS.has_key(name):
          out.line("        %s(_%s);" % (camel(1, "set", name), name))
    out.line("    }")

    # dispatch() calls the MethodDelegate callback named after this class.
    out.line()
    out.line("    public <C> void dispatch(C context, MethodDelegate<C> delegate) {")
    out.line("        delegate.%s(context, this);" % dromedary(self.name))
    out.line("    }")

    # index: position of the field in FIELDS, passed to the Field constructor.
    index = 0
    for type, name in self.fields:
      out.getter("boolean", camel(1, "has", name), "has_" + name)
      # clearX(): drop the has_ flag and reset the value to the type's
      # DEFAULTS entry, or null when DEFAULTS has none.
      out.setter("boolean", camel(1, "clear", name), "has_" + name, "false",
                 post = "this.%s = %s; this.dirty = true" % (name, DEFAULTS.get(type, "null")))
      out.getter(jtype(type), camel(1, "get", name), name)
      # Emit both setX(value) and a fluent alias x(value).
      for mname in (camel(1, "set", name), name):
        out.setter(jtype(type), mname, name,
                   post = "this.has_%s = true; this.dirty = true" % name)

      out.line()
      out.line('    static {')
      # Field's second type argument must be a reference type, so primitives
      # are boxed and generic arguments stripped.
      ftype = jref(jclass(jtype(type)))
      out.line('        FIELDS.add(new Field<%s,%s>(%s.class, %s.class, "%s", %d) {' %
               (self.name, ftype, self.name, ftype, name, index))
      out.line('            public boolean has(Object struct) {')
      out.line('                return check(struct).has_%s;' % name)
      out.line('            }')
      out.line('            public void has(Object struct, boolean value) {')
      out.line('                check(struct).has_%s = value;' % name)
      out.line('            }')
      out.line('            public %s get(Object struct) {' % ftype)
      out.line('                return check(struct).%s();' % camel(1, "get", name))
      out.line('            }')
      out.line('            public void read(Decoder dec, Object struct) {')
      # Spec types use Decoder.read<Type>(); struct types use readStruct with
      # the nested struct's TYPE code.
      if TYPES.has_key(type):
        out.line('                check(struct).%s = dec.read%s();' % (name, camel(0, type)))
      elif STRUCTS.has_key(type):
        out.line('                check(struct).%s = (%s) dec.readStruct(%s.TYPE);' %
                 (name, STRUCTS[type], STRUCTS[type]))
      else:
        raise Exception("unknown type: %s" % type)
      out.line('                check(struct).dirty = true;')
      out.line('            }')
      out.line('            public void write(Encoder enc, Object struct) {')
      if TYPES.has_key(type):
        out.line('                enc.write%s(check(struct).%s);' % (camel(0, type), name))
      elif STRUCTS.has_key(type):
        out.line('                enc.writeStruct(%s.TYPE, check(struct).%s);' %
                 (STRUCTS[type], name))
      else:
        raise Exception("unknown type: %s" % type)
      out.line('            }')
      out.line('        });')
      out.line('    }')
      index += 1;

    out.line("}")


  def parameters(self):
    """Return the Java constructor parameter list for this struct.

    Non-option fields become "<java type> <name>". If any field is an
    option, a trailing "Option ... _options" is appended.
    """
    params = []
    var = False
    for type, name in self.fields:
      if OPTIONS.has_key(name):
        var = True
      else:
        params.append("%s %s" % (jtype(type), name))
    if var:
      params.append("Option ... _options")
    return ", ".join(params)

  def arguments(self):
    """Return the Java argument list matching parameters().

    Lists the non-option field names, plus "_options" if any field is an
    option.
    """
    args = []
    var = False
    for type, name in self.fields:
      if OPTIONS.has_key(name):
        var = True
      else:
        args.append(name)
    if var:
      args.append("_options")
    return ", ".join(args)

# Spec classes excluded from generation: Visitor.do_method skips any class
# mapped to False here, and generates every class not listed.
CLASSES = dict.fromkeys(("file", "basic", "stream", "tunnel"), False)

# Width for each value of a struct's @pack attribute (2 when it is absent).
PACK_WIDTHS = {None: 2, "octet": 1, "short": 2, "long": 4}

# Width for each value of a struct's @size attribute: the same table, except
# that an absent @size gives 0 (presumably "no size prefix").
SIZE_WIDTHS = dict(PACK_WIDTHS)
SIZE_WIDTHS[None] = 0

class Visitor(mllib.transforms.Visitor):
  """Walks the parsed spec and records every struct-like definition.

  Each entry appended to self.structs is a tuple
  (java name, base class name, type code, size width, pack width, node).
  """

  def __init__(self):
    # NOTE(review): the base-class constructor is not called; presumably
    # mllib.transforms.Visitor needs no initialisation -- confirm.
    self.structs = []
    # Next code for a struct domain declared without a type. Counts down from
    # -1, presumably to stay clear of the declared codes.
    self.untyped = -1

  def do_method(self, m):
    """Record method *m* unless its class is disabled in CLASSES."""
    klass = m.parent
    if not CLASSES.get(klass["@name"], True):
      # Disabled class: neither recorded nor descended into.
      return
    java_name = camel(0, klass["@name"], m["@name"])
    code = int(klass["@index"]) * 256 + int(m["@index"])
    self.structs.append((java_name, "Method", code, 0, 2, m))
    self.descend(m)

  def do_domain(self, d):
    """Record the struct defined by domain *d* (if any), then descend."""
    node = d["struct"]
    if node:
      java_name = camel(0, d["@name"])
      declared = node["@type"]
      if declared in (None, "none", ""):
        code = self.untyped
        self.untyped -= 1
      else:
        code = int(declared)
      entry = (java_name, "Struct", code, SIZE_WIDTHS[node["@size"]],
               PACK_WIDTHS[node["@pack"]], node)
      self.structs.append(entry)
    self.descend(d)

  def do_result(self, r):
    """Record the struct returned by a method's result *r* (if any), then descend."""
    node = r["struct"]
    if node:
      method = r.parent
      klass = method.parent
      java_name = camel(0, klass["@name"], method["@name"], "Result")
      code = int(klass["@index"]) * 256 + int(node["@type"])
      entry = (java_name, "Result", code, SIZE_WIDTHS[node["@size"]],
               PACK_WIDTHS[node["@pack"]], node)
      self.structs.append(entry)
    self.descend(r)

# Walk the spec, collecting every method/struct/result definition.
v = Visitor()
spec.dispatch(v)

# Build a Struct model per definition. Along the way, emit Option.java with
# one constant per distinct bit-field name.
opts = Output(out_dir, out_pkg, "Option")
opts.line("public enum Option {")
structs = []
for sname, base, typecode, size, pack, node in v.structs:
  track = TRACKS.get(node.parent["@name"], "Frame.L4")
  has_content = node["@content"] == "1"
  struct = Struct(node, sname, base, typecode, size, pack, track, has_content)
  for fnode in node.query["field"]:
    ftype = resolve(fnode["@domain"])
    fname = camel(1, fnode["@name"])
    struct.field(ftype, fname)
    # Bit fields become Option flags; each name is declared only once.
    if ftype == "bit" and fname not in OPTIONS:
      opt_name = scream(fnode["@name"])
      OPTIONS[fname] = opt_name
      opts.line("    %s," % opt_name)
  structs.append(struct)
opts.line("    NO_OPTION,")
opts.line("}")
opts.write()




# Emit one Java source file per collected struct.
for st in structs:
  java = Output(out_dir, out_pkg, st.name)
  st.impl(java)
  java.write()

# StructFactory.java: instantiates a struct class from its TYPE code.
# NOTE(review): two structs sharing a TYPE value would produce duplicate case
# labels (a Java compile error) -- confirm codes are unique.
fct = Output(out_dir, out_pkg, "StructFactory")
for text in ("class StructFactory {",
             "    public static Struct create(int type) {",
             "        switch (type) {"):
  fct.line(text)
for s in structs:
  fct.line("        case %s.TYPE:" % s.name)
  fct.line("            return new %s();" % s.name)
for text in ("        default:",
             '            throw new IllegalArgumentException("type: " + type);',
             "        }",
             "    }",
             "}"):
  fct.line(text)
fct.write()

# MethodDelegate.java: one no-op callback per struct, named after the class.
# The generated dispatch() methods call these callbacks.
dlg = Output(out_dir, out_pkg, "MethodDelegate")
dlg.line("public abstract class MethodDelegate<C> {")
for s in structs:
  callback = dromedary(s.name)
  dlg.line("    public void %s(C context, %s struct) {}" % (callback, s.name))
dlg.line("}")
dlg.write()

# Invoker.java: one convenience method per spec method. Each builds the
# method struct and hands it to an abstract invoke(...). Methods that declare
# a result return Future<Result>.
# NOTE(review): Future is not among the imports Output emits; presumably it
# lives in the output package -- confirm.
inv = Output(out_dir, out_pkg, "Invoker")
inv.line("public abstract class Invoker {")
inv.line()
inv.line("    protected abstract void invoke(Method method);")
inv.line("    protected abstract <T> Future<T> invoke(Method method, Class<T> resultClass);")
inv.line()
for s in structs:
  if s.base == "Method":
    dname = dromedary(s.name)
    result = s.result()
    if result:
      signature = "Future<%s> %s" % (result, dname)
      call = "return invoke(new %s(%s), %s.class);" % (s.name, s.arguments(), result)
    else:
      signature = "void %s" % dname
      call = "invoke(new %s(%s));" % (s.name, s.arguments())
    inv.line("    public %s(%s) {" % (signature, s.parameters()))
    inv.line("        " + call)
    inv.line("    }")
inv.line("}")
inv.write()
