(*  Copyright 2006 Hendrik Tews, All rights reserved.                  *)
(*  See file license.txt for terms of use                              *)
(***********************************************************************)

open Printf
open More_string
open Meta_ast
open Ast_config


(******************************************************************************
 ******************************************************************************
 *
 * identifier utilities
 *
 ******************************************************************************
 ******************************************************************************)

(* Name component identifying [ml_type] in reflection function names:
 * the class name for nodes, the type name for base types, and for
 * options and lists the component of the element type followed by a
 * suffix for the container kind. Reference types must be removed
 * beforehand (AT_ref is rejected with assert false).
 *)
let rec reflection_function_type_component ml_type =
  let list_kind_suffix = function
    | LK_ast_list -> "_ast_list"
    | LK_fake_list -> "_fake_list"
    | LK_obj_list -> "_obj_list"
    | LK_sobj_list -> "_sobj_list"
    | LK_sobj_set -> "_sobj_set"
    | LK_string_obj_dict -> "_obj_dict"
    | LK_string_ref_map -> "_string_ref_map"
  in
  match ml_type with
  | AT_base type_name -> type_name
  | AT_node cl -> cl.ac_name
  | AT_option (inner, _) ->
      reflection_function_type_component inner ^ "_option"
  | AT_list (lk, inner, _) ->
      reflection_function_type_component inner ^ list_kind_suffix lk
  | AT_ref _ -> assert false


(* Name of the C++ function that reflects values of [ml_type]. The type
 * is passed through unref_type first, so a reference type maps to the
 * same name as the type it refers to. *)
let name_of_reflection_function ml_type =
  let component = reflection_function_type_component (unref_type ml_type) in
  sprintf "ocaml_reflect_%s" component

(* Name of the C++ function that wraps the record subclass [cl] into the
 * variant constructor of its super class. This is the record-variant
 * counterpart of name_of_reflection_function.
 *)
let name_of_reflection_record_variant cl =
  "ocaml_reflect_" ^ cl.ac_name ^ "_variant"


(******************************************************************************
 ******************************************************************************
 *
 * type definition
 *
 ******************************************************************************
 ******************************************************************************)

(* Maps each record label to the context (class name) that first used it. *)
let record_field_hash : (string, string) Hashtbl.t = Hashtbl.create 1901

(* Register [field_name] as a record label used in [context] and return
 * [field_name] unchanged. If the label was already registered, print a
 * warning on stderr naming both contexts. The first registration is
 * kept.
 *)
let check_record_field context field_name =
  let first_owner =
    try Some (Hashtbl.find record_field_hash field_name)
    with Not_found -> None
  in
  (match first_owner with
   | Some owner ->
       eprintf "Warning: record label %s used both in %s and %s\n"
         field_name context owner
   | None -> Hashtbl.add record_field_hash field_name context);
  field_name


(* Tag of each variant constructor, keyed by subclass name. The tag is
 * the constructor's position in its super class's type definition;
 * generate_variant_decl fills it in and the reflection generators read it.
 *)
let constructor_tags : (string, int) Hashtbl.t = Hashtbl.create 557


(* Print the OCaml record type body for class [cl]. The polymorphic
 * annotation field comes first, then every field of get_all_fields in
 * order. Each label is registered with check_record_field under the
 * translated class name.
 *)
let generate_record_type oc cl =
  output_string oc "{\n";
  let owner = translated_class_name cl in
  let print_label label field_type =
    fprintf oc "  %s : %s;\n" (check_record_field owner label) field_type
  in
  print_label (annotation_field_name cl) "'a";
  List.iter
    (fun group ->
       List.iter
         (fun field ->
            print_label
              (translated_field_name cl field)
              (string_of_ast_type true field.af_mod_type))
         group)
    (get_all_fields cl);
  output_string oc "}\n"


(* Print the constructor of a non-record subclass in tuple form:
 * "  | Name of 'a * t1 * ... * tn\n". Here 'a is the annotation and
 * t1..tn are the field types from get_all_fields, in order.
 *)
let generate_variant_decl_tuple oc cl_sub =
  let field_groups = get_all_fields cl_sub in
  fprintf oc "  | %s of 'a" (variant_name cl_sub);
  List.iter
    (fun field ->
       fprintf oc " * %s" (string_of_ast_type true field.af_mod_type))
    (List.concat field_groups);
  output_string oc "\n"
  

(* Print the constructor of a record subclass. Its single argument is the
 * subclass's own record type. *)
let generate_variant_decl_record oc cl_sub =
  let constructor = variant_name cl_sub in
  let record_type = string_of_ast_type true (AT_node cl_sub) in
  fprintf oc "  | %s of %s\n" constructor record_type


(* Print the constructor for [cl_sub] in its super class's variant type.
 * The current value of [variant_count] is recorded as the constructor's
 * tag in constructor_tags, then the counter is advanced.
 *)
let generate_variant_decl oc variant_count cl_sub =
  let print_decl =
    if cl_sub.ac_record then generate_variant_decl_record
    else generate_variant_decl_tuple
  in
  print_decl oc cl_sub;
  let tag = !variant_count in
  Hashtbl.add constructor_tags cl_sub.ac_name tag;
  variant_count := tag + 1
  

(* Print the start of the type definition for [cl].
 * - First definition: a comment followed by "type ", and [first] is
 *   cleared.
 * - Every later definition: "and ".
 * In both cases "<type> = " follows.
 *)
let type_def_header first oc cl =
  let keyword =
    if !first then begin
      first := false;
      pr_comment oc ["syntax tree type definition"; ""];
      "type "
    end
    else "and "
  in
  output_string oc keyword;
  fprintf oc "%s = " (string_of_ast_type true (AT_node cl))


(* Print the type definition(s) for class [super].
 * - Without subclasses: a single record type.
 * - With subclasses: a variant type with one constructor per subclass,
 *   followed by a separate record type for every subclass flagged
 *   ac_record.
 * [first] is the shared flag used by type_def_header.
 *)
let generate_type_def first oc super =
  type_def_header first oc super;
  match super.ac_subclasses with
  | [] ->
      generate_record_type oc super;
      output_string oc "\n\n"
  | subclasses ->
      let variant_count = ref 0 in
      output_string oc "\n";
      List.iter (generate_variant_decl oc variant_count) subclasses;
      output_string oc "\n\n";
      List.iter
        (fun cl ->
           if cl.ac_record then begin
             type_def_header first oc cl;
             generate_record_type oc cl;
             output_string oc "\n\n"
           end)
        subclasses


(* Write the generated OCaml type definition file to [oc], in order:
 * - a do-not-edit banner naming [source];
 * - the optional OCaml header lines from the control file;
 * - the type definitions for every class in [ast], joined with "and"
 *   (see type_def_header).
 *)
let generate_type_def_file source ast oc =
  let banner =
    [do_not_edit_line;
     "";
     "automatically generated by gen_reflection from " ^ source;
     "";
     "************************************************************************";
     "********************** Ast type definition *****************************";
     "************************************************************************";
    ]
  in
  pr_comment oc banner;
  output_string oc "\n\n";
  let control_header = get_ocaml_type_header () in
  if control_header <> [] then
    pr_comment oc ["header from control file"; ""];
  List.iter (fun line -> output_string oc (line ^ "\n")) control_header;
  output_string oc "\n\n";
  let first = ref true in
  List.iter (generate_type_def first oc) ast
    


(******************************************************************************
 ******************************************************************************
 *
 * ocaml_reflection : declarations and function register
 *
 ******************************************************************************
 ******************************************************************************)

(* Set of ML types whose reflection function has been declared (values
 * are unit). generate_refl_decl fills it; generate_ocaml_reflect_cc
 * emits one definition per entry. *)
let reflection_type_hash : (_, unit) Hashtbl.t = Hashtbl.create 571

(* C++ prototype (without a trailing ";" or body) of the reflection
 * function for [ml_type], of the form "value NAME(TYPE * x)". The "* " is
 * omitted when is_implicit_pointer_type reports the C type of a base or
 * option type as implicit pointer (presumably a pointer typedef).
 *)
let string_refl_func_header ml_type =
  (*
   * The annot field of MemberInit, Handler, FullExpression and Initializer
   * is private and the respective getAnnot methods are not const.
   * Therefore the whole reflection is not const.
   * sprintf "value %s(%s const * x, Ocaml_reflection_data * data)"
   *)
  let list_template = function
    | LK_ast_list -> "ASTList"
    | LK_fake_list -> "FakeList"
    | LK_obj_list -> "ObjList"
    | LK_sobj_list -> "SObjList"
    | LK_sobj_set -> "SObjSet"
    | LK_string_obj_dict -> "StringObjDict"
    | LK_string_ref_map -> "StringRefMap"
  in
  let param_type, already_pointer =
    match ml_type with
    | AT_node cl -> (cl.ac_name, false)
    | AT_base type_name -> (type_name, is_implicit_pointer_type type_name)
    | AT_option (_, c_type) -> (c_type, is_implicit_pointer_type c_type)
    | AT_list (lk, _, c_type) ->
        (sprintf "%s<%s>" (list_template lk) c_type, false)
    | AT_ref _ -> assert false
  in
  let pointer_marker = if already_pointer then "" else "* " in
  sprintf "value %s(%s %sx)"
    (name_of_reflection_function ml_type) param_type pointer_marker



(* Declare the C++ reflection function for [ml_type] once, and
 * recursively the functions for related types.
 * - Lists and options: the element type.
 * - Nodes: the super class, or, for a node without super class, all of
 *   its subclasses.
 * Base types only get a "relying on extern" comment. The type is
 * recorded in reflection_type_hash before recursing, so each type is
 * handled at most once.
 *)
let rec generate_refl_decl oc ml_type =
  if not (Hashtbl.mem reflection_type_hash ml_type) then begin
    Hashtbl.add reflection_type_hash ml_type ();
    (* the declaration (or comment) for this type *)
    (match ml_type with
     | AT_base _ ->
         pr_c_comment oc
           ["relying on extern"; string_refl_func_header ml_type]
     | AT_list _ | AT_option _ | AT_node _ ->
         output_string oc (string_refl_func_header ml_type);
         output_string oc ";\n"
     | AT_ref _ -> assert false);
    (* related types whose functions must be declared as well *)
    let related =
      match ml_type with
      | AT_list (_, inner, _) | AT_option (inner, _) -> [inner]
      | AT_node cl ->
          (match cl.ac_super with
           | Some super -> [AT_node super]
           | None -> List.map (fun sub -> AT_node sub) cl.ac_subclasses)
      | AT_base _ -> []
      | AT_ref _ -> assert false
    in
    List.iter (generate_refl_decl oc) related
  end


(* Declare the reflection functions for every configured top node and
 * everything reachable from it. A Not_found raised while handling a top
 * node is reported on stderr as a missing top node.
 *)
let generate_reflection_decls_top_nodes oc =
  let declare_top_node top =
    try generate_refl_decl oc (AT_node (get_node top))
    with Not_found ->
      eprintf "Warning: Requested top node %s not found\n" top
  in
  List.iter declare_top_node (get_ast_top_nodes ())


(* Declare the reflection functions needed by the fields of [node] and of
 * its direct subclasses. For each class the ac_args, ac_last_args and
 * ac_fields are taken in that order, with field types passed through
 * unref_type.
 *)
let generate_reflection_decls oc node =
  let declare_field field =
    generate_refl_decl oc (unref_type field.af_mod_type)
  in
  let declare_class cl =
    List.iter (List.iter declare_field)
      [cl.ac_args; cl.ac_last_args; cl.ac_fields]
  in
  List.iter declare_class (node :: node.ac_subclasses)
      

(******************************************************************************
 ******************************************************************************
 *
 * ocaml_reflection : function definitions
 *
 ******************************************************************************
 ******************************************************************************)

(* C++ expression reading the private field [field] of class [cl] through
 * the pointer expression [obj]. If get_private_accessor has an accessor
 * for (class, field), that accessor is used. Otherwise the getter
 * convention "obj->getField()" applies, with the field name capitalized.
 *)
let private_accessor obj cl field =
  try get_private_accessor cl.ac_name field.af_name
  with Not_found ->
    (* String.capitalize is deprecated and was removed in OCaml 5.0; the
     * field names are C++ identifiers, so the ASCII variant is what we
     * want anyway *)
    sprintf "%s->get%s()" obj (String.capitalize_ascii field.af_name)


(* C++ expression that reflects [var] with the function for [ml_type].
 * For reference types the call is additionally wrapped in
 * make_reference(...). *)
let string_of_reflection_call ml_type var =
  let call = sprintf "%s(%s)" (name_of_reflection_function ml_type) var in
  match ml_type with
  | AT_ref _ -> "make_reference(" ^ call ^ ")"
  | AT_base _ | AT_node _ | AT_option _ | AT_list _ -> call


(* Emit the body of the reflection function for a class [cl] without
 * subclasses.
 *
 * Record vs. variant: the value is built as a record (tag 0) when [cl]
 * has no super class or is flagged ac_record. Otherwise it is built as
 * its super class's variant constructor, using the tag recorded in
 * constructor_tags.
 *
 * Layout: field 0 holds ocaml_ast_annotation(x); after it come all
 * fields of get_all_fields in order. Private fields are read through
 * private_accessor. Non-pointer public members are passed by address.
 *
 * A fresh shared-value id is taken from [shared_val_type]. The generated
 * code first consults find_ocaml_shared_value and registers the new
 * block with register_ocaml_shared_value (presumably so that a C++ node
 * reached twice maps to a single OCaml value).
 *)
let generate_block_reflection oc shared_val_type cl =
  let out = output_string oc in
  let fpf format = fprintf oc format in
  (* Pattern match instead of the deprecated [or] operator and instead of
   * polymorphic equality on the class record, which links to its super
   * class *)
  let is_record =
    match cl.ac_super with
    | None -> true
    | Some _ -> cl.ac_record
  in
  let fields = get_all_fields cl in
  let field_counter = ref 0 in
  let my_val_type = incr shared_val_type; !shared_val_type in
  out "  ";
  pr_c_comment oc
    [sprintf "reflect %s into a %s"
       cl.ac_name
       (if is_record then "record" else "variant")];
  out "\n";
  out "  CAMLparam0();\n";
  out "  CAMLlocal2(res, child);\n\n";

  fpf "  res = find_ocaml_shared_value(x, %d);\n" my_val_type;
  out "  if(res != Val_None) {\n";
  out "    xassert(Is_block(res) && Tag_val(res) == 0 &&";
  out " Wosize_val(res) == 1);\n";
  out "    CAMLreturn(Field(res, 0));\n";
  out "  }\n\n";

  (* one slot for the annotation plus one per field *)
  fpf "  res = caml_alloc(%d, %d);\n"
    (count_fields fields + 1)
    (if is_record then 0 else Hashtbl.find constructor_tags cl.ac_name);
  fpf "  register_ocaml_shared_value(x, res, %d);\n\n" my_val_type;

  out "  child = ocaml_ast_annotation(x);\n";
  fpf "  Store_field(res, %d, child);\n\n" !field_counter;
  incr field_counter;

  List.iter
    (List.iter
       (fun field ->
          let access =
            if field.af_is_private then private_accessor "x" cl field
            else
              sprintf "%sx->%s"
                (if field.af_is_pointer then "" else "&")
                field.af_name
          in
          fpf "  child = %s;\n"
            (string_of_reflection_call field.af_mod_type access);
          fpf "  Store_field(res, %d, child);\n\n" !field_counter;
          incr field_counter))
    fields;

  out "  CAMLreturn(res);\n";
  out "}\n\n\n"


(* Emit the complete C++ function name_of_reflection_record_variant for
 * the record subclass [recsub]. It wraps the record reflection of *x in
 * a one-field block tagged with the subclass's constructor tag from
 * constructor_tags. The block goes through the shared-value lookup and
 * registration under a fresh id taken from [shared_val_type].
 *)
let generate_record_variant_reflection oc shared_val_type recsub =
  let out = output_string oc in
  let fpf format = fprintf oc format in
  incr shared_val_type;
  let my_val_type = !shared_val_type in
  fpf "value %s(%s *x) {\n"
    (name_of_reflection_record_variant recsub)
    recsub.ac_name;
  List.iter out
    [ "  CAMLparam0();\n";
      "  CAMLlocal2(res, child);\n\n" ];
  fpf "  res = find_ocaml_shared_value(x, %d);\n" my_val_type;
  List.iter out
    [ "  if(res != Val_None) {\n";
      "    xassert(Is_block(res) && Tag_val(res) == 0 &&";
      " Wosize_val(res) == 1);\n";
      "    CAMLreturn(Field(res, 0));\n";
      "  }\n\n" ];
  fpf "  res = caml_alloc(1, %d);\n"
    (Hashtbl.find constructor_tags recsub.ac_name);
  fpf "  register_ocaml_shared_value(x, res, %d);\n\n" my_val_type;
  fpf "  child = %s;\n" (string_of_reflection_call (AT_node recsub) "x");
  List.iter out
    [ "  Store_field(res, 0, child);\n\n";
      "  CAMLreturn(res);\n";
      "}\n\n\n" ]


(* Emit the body of the reflection function for a class [super] that has
 * subclasses. The body switches on the kind of *x and returns the
 * reflection of the matching subclass after downcasting.
 *
 * Three pieces are configurable, each with a default used when the
 * lookup raises Not_found:
 * - kind accessor (superclass_get_kind): "kind()";
 * - downcast expression (get_downcast): "x->asSub()";
 * - case label tag (get_subclass_tag): the upper-cased subclass name.
 *
 * Record subclasses go through their record-variant wrapper
 * (name_of_reflection_record_variant). Any other kind reaches
 * xassert(false).
 *)
let generate_downcast_reflection oc super =
  let out = output_string oc in
  let fpf format = fprintf oc format in
  let kind_accessor =
    try superclass_get_kind super.ac_name with Not_found -> "kind()"
  in
  fpf "  switch(x->%s) {\n" kind_accessor;
  List.iter
    (fun sub ->
       let downcast_expression =
         try get_downcast sub.ac_name
         with Not_found -> sprintf "x->as%s()" sub.ac_name
       in
       let case_tag =
         try get_subclass_tag sub.ac_name
         with Not_found ->
           (* String.uppercase is deprecated and was removed in OCaml 5.0;
            * class names are ASCII C++ identifiers *)
           String.uppercase_ascii sub.ac_name
       in
       fpf "  case %s::%s:\n" super.ac_name case_tag;
       fpf "    return %s;\n\n"
         (if sub.ac_record then
            sprintf "%s(%s)"
              (name_of_reflection_record_variant sub)
              downcast_expression
          else
            (* use non-const version,
             * see comment in string_refl_func_header *)
            string_of_reflection_call (AT_node sub) downcast_expression))
    super.ac_subclasses;
  out "  default:\n";
  out "    xassert(false);\n";
  out "    break;\n";
  out "  }\n";
  out "  xassert(false);\n";
  out "}\n\n\n"


(* Emit the body of the reflection function for a list of kind [lk]
 * whose elements have ML type [ml_type] and C++ type [c_type].
 * StringRefMap is not handled here; generate_refl_defn routes it to
 * generate_hash_refl_defn.
 *
 * The generated code:
 * - returns [] for a NULL pointer;
 * - consults find_ocaml_shared_value under a fresh id taken from
 *   [shared_val_type];
 * - iterates with the container's *_NC macro, appending one cons cell per
 *   element through [previous];
 * - registers the first cell, or [] if the container is empty, with
 *   register_ocaml_shared_value.
 *)
let generate_list_refl_defn oc shared_val_type lk ml_type c_type =
  let out = output_string oc in
  let fpf format = fprintf oc format in
  let my_val_type = incr shared_val_type; !shared_val_type in
  (* iteration macro; whether the macro takes the container object ( *x )
   * instead of the pointer; name of the iteration variable; C++
   * expression for the current element *)
  let (list_combi, must_dereference, iter_name, iter_access) =
    match lk with
	(* use non-const versions, see comment in string_refl_func_header *)
      | LK_ast_list -> ("FOREACH_ASTLIST_NC", true, "iter", "iter.data()")
      | LK_fake_list -> ("FAKELIST_FOREACH_NC", false, "ptr", "ptr")
      | LK_obj_list -> ("FOREACH_OBJLIST_NC", true, "iter", "iter.data()")
      | LK_sobj_list -> ("SFOREACH_OBJLIST_NC", true, "iter", "iter.data()")
      | LK_sobj_set -> ("FOREACH_SOBJSET_NC", true, "iter", "iter.data()")
      | LK_string_obj_dict -> 
	  ("FOREACH_STRINGOBJDICT_NC", true, "iter", 
	   sprintf "const_cast<%s *>(iter.value())" c_type)
      (* handled by generate_hash_refl_defn *)
      | LK_string_ref_map -> assert false

  in
    out "  CAMLparam0();\n";
    out "  CAMLlocal4(res, tmp, elem, previous);\n\n";

    (* not necessary for ASTList and FakeList, but doesn't harm, so ... *)
    out "  if( x == NULL) CAMLreturn(Val_emptylist);\n\n";

    fpf "  res = find_ocaml_shared_value(x, %d);\n" my_val_type;
    out "  if(res != Val_None) {\n";
    out "    xassert(Is_block(res) && Tag_val(res) == 0 &&";
    out " Wosize_val(res) == 1);\n";
    out "    CAMLreturn(Field(res, 0));\n";
    out "  }\n\n";

    (* 0 marks "no cell allocated yet" in the generated loop *)
    out "  previous = 0;\n\n";

    fpf "  %s(%s, %sx, %s) {\n" 
      list_combi c_type 
      (if must_dereference then "*" else "")
      iter_name;
    out "    tmp = caml_alloc(2, Tag_cons);\n";
    (* the first cell becomes the result and is registered right away *)
    out "    if(previous == 0) {\n";
    out "      res = tmp;\n";
    fpf "      register_ocaml_shared_value(x, res, %d);\n" my_val_type;
    out "    } else {\n";
    out "      Store_field(previous, 1, tmp);\n";
    out "    }\n";
    fpf "    elem = %s;\n"
      (string_of_reflection_call ml_type iter_access);
    out "    Store_field(tmp, 0, elem);\n";
    out "    Store_field(tmp, 1, Val_emptylist);\n";
    out "    previous = tmp;\n";
    out "  }\n\n";

    (* empty container: the result is [], registered as well *)
    out "  if(previous == 0){\n";
    out "    res = Val_emptylist;\n";
    fpf "    register_ocaml_shared_value(x, res, %d);\n" my_val_type;
    out "  }\n\n";

    out "  CAMLreturn(res);\n";
    out "}\n\n\n"


(* Emit the body of the reflection function for a StringRefMap whose
 * values have ML type [ml_type] and C++ type [c_type]. The result is an
 * OCaml hash table.
 *
 * The generated code:
 * - consults find_ocaml_shared_value under a fresh id taken from
 *   [shared_val_type];
 * - fetches the closures registered under the names "hashtbl_create" and
 *   "hashtbl_add" via caml_named_value (presumably registered on the
 *   OCaml side with Callback.register) and caches them in statics;
 * - creates a table sized by x->getNumEntries() and registers it;
 * - adds every entry, with the key reflected by ocaml_reflect_cstring and
 *   the value by the reflection function for [ml_type].
 *
 * NOTE(review): unlike the list and option bodies, no x == NULL check is
 * emitted. Confirm that StringRefMap pointers are never NULL.
 *)
let generate_hash_refl_defn oc shared_val_type ml_type c_type =
  let out = output_string oc in
  let fpf format = fprintf oc format in
  let my_val_type = incr shared_val_type; !shared_val_type
  in
    out "  CAMLparam0();\n";
    out "  CAMLlocal5(res, hash_size, key, var, val);\n\n";

    fpf "  res = find_ocaml_shared_value(x, %d);\n" my_val_type;
    out "  if(res != Val_None) {\n";
    out "    xassert(Is_block(res) && Tag_val(res) == 0 &&";
    out " Wosize_val(res) == 1);\n";
    out "    CAMLreturn(Field(res, 0));\n";
    out "  }\n\n";

    (* closures are looked up once per generated function and cached *)
    out "  static value * hashtbl_create_closure = NULL;\n";
    out "  static value * hashtbl_add_closure = NULL;\n";
    out "  if(hashtbl_create_closure == NULL)\n";
    out "    hashtbl_create_closure = caml_named_value(\"hashtbl_create\");\n";
    out "  xassert(hashtbl_create_closure);\n";
    out "  if(hashtbl_add_closure == NULL)\n";
    out "    hashtbl_add_closure = caml_named_value(\"hashtbl_add\");\n";
    out "  xassert(hashtbl_add_closure);\n\n";

    (* the entry count must fit into an OCaml int before Val_int *)
    out "  xassert(x->getNumEntries() <= Max_long &&";
    out " Min_long <= x->getNumEntries());\n";
    out "  hash_size = Val_int(x->getNumEntries());\n";
    out "  xassert(IS_OCAML_AST_VALUE(hash_size));\n";
    out "  res = caml_callback(*hashtbl_create_closure, hash_size);\n";
    out "  xassert(IS_OCAML_AST_VALUE(res));\n";
    fpf "  register_ocaml_shared_value(x, res, %d);\n\n" my_val_type;

    (* add every (key, value) pair; the add closure must return unit *)
    fpf "  FOREACH_STRINGREFMAP_NC(%s, *x, iter) {\n" c_type;
    out "    key = ocaml_reflect_cstring(iter.key());\n";
    fpf "    var = %s;\n"
      (string_of_reflection_call ml_type "iter.value()");
    out "    val = caml_callback3(*hashtbl_add_closure, res, key, var);\n";
    out "    xassert(val == Val_unit);\n";
    out "  }\n\n";
    
    out "  CAMLreturn(res);\n";
    out "}\n\n\n"


(* Emit the body of the reflection function for an option whose payload
 * has ML type [ml_type].
 *
 * The generated code:
 * - maps NULL to None;
 * - otherwise consults find_ocaml_shared_value under a fresh id taken
 *   from [shared_val_type];
 * - allocates a Tag_some block, registers it, and stores the reflected
 *   payload of x in it.
 *)
let generate_option_refl_defn oc shared_val_type ml_type =
  let out = output_string oc in
  let fpf format = fprintf oc format in
  incr shared_val_type;
  let my_val_type = !shared_val_type in
  List.iter out
    [ "  CAMLparam0();\n";
      "  CAMLlocal2(res, elem);\n\n";
      "  if(x == NULL) CAMLreturn(Val_None);\n\n" ];
  fpf "  res = find_ocaml_shared_value(x, %d);\n" my_val_type;
  List.iter out
    [ "  if(res != Val_None) {\n";
      "    xassert(Is_block(res) && Tag_val(res) == 0 &&";
      " Wosize_val(res) == 1);\n";
      "    CAMLreturn(Field(res, 0));\n";
      "  }\n\n";
      "  res = caml_alloc(1, Tag_some);\n" ];
  fpf "  register_ocaml_shared_value(x, res, %d);\n" my_val_type;
  fpf "  elem = %s;\n" (string_of_reflection_call ml_type "x");
  List.iter out
    [ "  Store_field(res, 0, elem);\n";
      "  CAMLreturn(res);\n";
      "}\n\n\n" ]


  
(* Emit the full definition of the reflection function for [ml_type]: the
 * header from string_refl_func_header, then a body chosen by type.
 * - StringRefMap: hash table.
 * - Other lists: OCaml list.
 * - Options: option.
 * - Classes without subclasses: block.
 * - Classes with subclasses: downcast switch.
 * Base and reference types have no generated definition (assert false).
 *)
let generate_refl_defn oc shared_val_type ml_type =
  output_string oc (string_refl_func_header ml_type ^ " {\n");
  match ml_type with
  | AT_list (LK_string_ref_map, inner, c_type) ->
      generate_hash_refl_defn oc shared_val_type inner c_type
  | AT_list (lk, inner, c_type) ->
      generate_list_refl_defn oc shared_val_type lk inner c_type
  | AT_option (inner, _) ->
      generate_option_refl_defn oc shared_val_type inner
  | AT_node cl ->
      (match cl.ac_subclasses with
       | [] -> generate_block_reflection oc shared_val_type cl
       | _ :: _ -> generate_downcast_reflection oc cl)
  | AT_base _ | AT_ref _ -> assert false


(* Write the C++ reflection traversal file to [oc], in order:
 * - a banner naming [source];
 * - an #include of the basename of [header];
 * - declarations for the top nodes and for all field types of [ast];
 * - the record-variant wrappers;
 * - the definitions of all declared functions, sorted by prototype text
 *   so the output does not depend on hash table order.
 * All generators share one counter for shared-value ids.
 *)
let generate_ocaml_reflect_cc source header ast oc =
  let out = output_string oc in
  let shared_val_type = ref 0 in
  let section_comment title =
    pr_c_comment oc
      ["**********************************************************************";
       title];
    out "\n\n"
  in
  pr_c_comment oc
    [do_not_edit_line;
     "";
     "automatically generated by gen_reflection from " ^ source;
     "";
     "**********************************************************************";
     "*************** Ocaml reflection tree traversal **********************";
     "**********************************************************************";
    ];
  out "\n\n";
  fprintf oc "#include \"%s\"\n" (Filename.basename header);
  out "\n\n";

  section_comment "Reflection functions declarations";
  generate_reflection_decls_top_nodes oc;
  List.iter (generate_reflection_decls oc) ast;
  out "\n\n";

  section_comment "Reflection functions definitions";
  (* wrappers putting record subclasses into their super class's variant *)
  List.iter
    (fun super ->
       List.iter
         (fun sub ->
            if sub.ac_record then
              generate_record_variant_reflection oc shared_val_type sub)
         super.ac_subclasses)
    ast;
  (* compare only the header strings, never the (linked) ML types *)
  let declared =
    Hashtbl.fold
      (fun ml_type () acc -> (ml_type, string_refl_func_header ml_type) :: acc)
      reflection_type_hash []
  in
  let by_header =
    List.fast_sort (fun (_, h1) (_, h2) -> compare h1 h2) declared
  in
  List.iter
    (fun (ml_type, _) ->
       match ml_type with
       | AT_base _ -> ()
       | AT_list _ | AT_option _ | AT_node _ ->
           generate_refl_defn oc shared_val_type ml_type
       | AT_ref _ -> assert false)
    by_header;
  out "\n"


(******************************************************************************
 ******************************************************************************
 *
 * ocaml_reflection : top nodes include file
 *
 ******************************************************************************
 ******************************************************************************)

(* Write the C++ header declaring the reflection functions of the top
 * nodes to [oc], in order:
 * - a banner naming [source];
 * - an include guard derived from the basename of [this_file];
 * - the optional header lines from the control file;
 * - one prototype per top node.
 * Top nodes that get_node cannot find are skipped silently here (the .cc
 * generator warns about them).
 *)
let generate_ocaml_reflect_h source this_file oc =
  let out = output_string oc in
  let fpf format = fprintf oc format in
  (* Guard macro: the file basename with "." translated to "_" by
   * More_string.translate, in upper case. String.uppercase is deprecated
   * and was removed in OCaml 5.0; a macro name is ASCII anyway. *)
  let cpp_symbol =
    String.uppercase_ascii (translate "." "_" (Filename.basename this_file))
  in
  pr_c_comment oc
    [do_not_edit_line;
     "";
     "automatically generated by gen_reflection from " ^ source;
     "";
     "**********************************************************************";
     "****************** Ocaml reflection top node *************************";
     "**********************************************************************";
    ];
  out "\n";
  fpf "#ifndef %s\n" cpp_symbol;
  fpf "#define %s\n" cpp_symbol;
  out "\n\n";

  if get_ocaml_reflect_header () <> [] then
    pr_c_comment oc ["header from control file"; ""];
  List.iter
    (fun line ->
       out line;
       out "\n")
    (get_ocaml_reflect_header ());
  out "\n\n";

  pr_c_comment oc
    ["**********************************************************************";
     "Reflection functions declarations for the top node";
    ];
  out "\n\n";

  List.iter
    (fun top ->
       try
         output_string oc (string_refl_func_header (AT_node (get_node top)));
         output_string oc ";\n"
       with
       | Not_found -> ())
    (get_ast_top_nodes ());

  out "\n\n";
  fpf "#endif // %s\n" cpp_symbol


(******************************************************************************
 ******************************************************************************
 *
 * arguments and main
 *
 ******************************************************************************
 ******************************************************************************)


(* Output prefix given with -o, if any. *)
let output_prefix_option : string option ref = ref None

(* Extra command line options passed to setup_ml_ast. *)
let arguments =
  let set_output_prefix prefix = output_prefix_option := Some prefix in
  [ ("-o", Arg.String set_output_prefix, "out set output prefix to out") ]


(* Parse the command line, read the .oast file and write the three output
 * files: <prefix>_type.ml, <prefix>_ocaml_reflect.cc and
 * <prefix>_ocaml_reflect.h. The prefix is the -o argument or, failing
 * that, the input file name with a ".oast" suffix removed.
 *)
let main () =
  let (oast_file, ast) = setup_ml_ast arguments "gen_reflection" in
  let output_prefix =
    match !output_prefix_option with
    | Some prefix -> prefix
    | None ->
        if Filename.check_suffix oast_file ".oast"
        then Filename.chop_suffix oast_file ".oast"
        else oast_file
  in
  let file ext = output_prefix ^ ext in
  (* Open <prefix><ext>, run [action] on it and close it. The channel is
   * closed even when [action] raises; the exception is then re-raised. *)
  let with_file ext comment action =
    let file_name = file ext in
    let oc = open_out file_name in
    (try action oc with
     | e ->
         close_out_noerr oc;
         raise e);
    close_out oc;
    eprintf "wrote %s (%s)\n" file_name comment
  in
  let header_ext = "_ocaml_reflect.h" in
  with_file "_type.ml" "type definition"
    (generate_type_def_file oast_file ast);
  with_file "_ocaml_reflect.cc" "ocaml reflection traversal"
    (generate_ocaml_reflect_cc oast_file (file header_ext) ast);
  with_file header_ext "ocaml reflection top node declaration"
    (generate_ocaml_reflect_h oast_file (file header_ext))
;;


(* Program entry point. An assertion failure is reported with its source
 * position, and the program exits with status 1. *)
let () =
  try main () with
  | Assert_failure (file, line, column) ->
      eprintf "File \"%s\", line %d, character %d: Assertion failure\n"
        file line column;
      exit 1


