(* Simple arithmetic expressions: the source language of the compiler below.
   [eval_expr] gives the reference meaning of each constructor. *)
type expr =
| EInt of int (* integer literal *)
| EAdd of expr * expr (* sum of the two sub-expressions *)
| EMul of expr * expr (* product of the two sub-expressions *)
(* A register operand, identified by its integer index. *)
type reg = [ `L_Reg of int ]
(* An immediate (constant) operand. *)
type src = [ `L_Int of int ]
(* Target instructions. Each one writes its result into the first (dst)
   register; IAdd/IMul read their two source registers. *)
type instr =
| ILoad of reg * src (* dst, src *)
| IAdd of reg * reg * reg (* dst, src1, src2 *)
| IMul of reg * reg * reg (* dst, src1, src2 *)
(* Reference interpreter: compute the value of an expression directly with
   OCaml's native int arithmetic (overflow wraps silently). Used as the
   oracle the compiled code is checked against. *)
let rec eval_expr e =
  match e with
  | EInt value -> value
  | EAdd (lhs, rhs) ->
    let a = eval_expr lhs in
    let b = eval_expr rhs in
    a + b
  | EMul (lhs, rhs) ->
    let a = eval_expr lhs in
    let b = eval_expr rhs in
    a * b
(* Interpreter register file: maps a register index to the value it holds. *)
type regs = (int, int) Hashtbl.t
(* Execute one instruction against the register table [rs], mutating it in
   place.

   A write overwrites whatever the destination register held before.
   [Hashtbl.replace] is used rather than [Hashtbl.add], so rewriting a
   register does not stack hidden bindings in the table.

   The function is not recursive, so it is bound with plain [let]; a
   [let rec] here triggers warning 39 (unused rec flag).

   Raises [Not_found] if an IAdd/IMul reads a register that has not been
   written yet. *)
let run_instr (rs:regs) = function
  | ILoad (`L_Reg r, `L_Int n) -> Hashtbl.replace rs r n
  | IAdd (`L_Reg r1, `L_Reg r2, `L_Reg r3) ->
    Hashtbl.replace rs r1 (Hashtbl.find rs r2 + Hashtbl.find rs r3)
  | IMul (`L_Reg r1, `L_Reg r2, `L_Reg r3) ->
    Hashtbl.replace rs r1 (Hashtbl.find rs r2 * Hashtbl.find rs r3)
(* Run the program [p] front to back, each instruction updating [rs]. *)
let run_prog (rs:regs) (p:instr list) =
  let rec step = function
    | [] -> ()
    | ins :: rest ->
      run_instr rs ins;
      step rest
  in
  step p
(* Fresh-register supply: returns 0, 1, 2, ... on successive calls. The
   counter is global and never reset, so numbers are unique across every
   compilation in the program run. *)
let next_reg =
  let counter = ref 0 in
  fun () ->
    let fresh = !counter in
    incr counter;
    fresh
(* Compile [e] to straight-line code. Returns [(r, code)]: running [code]
   leaves the value of [e] in register [r].

   Registers come from the global [next_reg] supply. For a binary node the
   left operand is compiled first, then the right, then the destination
   register is allocated, so numbering is deterministic.

   Instructions are pushed in reverse onto a single accumulator and
   reversed once at the end, which keeps compilation linear in the size of
   [e]. Joining sub-programs with [@] at every node would copy the left
   operand each time and go quadratic on deep expressions. *)
let comp_expr e =
  (* [emit acc e]: push the code for [e], reversed, onto [acc]; return the
     register holding [e]'s value together with the extended accumulator. *)
  let rec emit acc = function
    | EInt n ->
      let r = next_reg () in
      (r, ILoad (`L_Reg r, `L_Int n) :: acc)
    | EAdd (e1, e2) -> emit_binop acc e1 e2 (fun d a b -> IAdd (d, a, b))
    | EMul (e1, e2) -> emit_binop acc e1 e2 (fun d a b -> IMul (d, a, b))
  (* Shared code for binary operators: operands first (e1 before e2), then
     allocate the destination and emit the operation built by [mk]. *)
  and emit_binop acc e1 e2 mk =
    let (r1, acc) = emit acc e1 in
    let (r2, acc) = emit acc e2 in
    let r = next_reg () in
    (r, mk (`L_Reg r) (`L_Reg r1) (`L_Reg r2) :: acc)
  in
  let (r, rev_code) = emit [] e in
  (r, List.rev rev_code)
(* Notes:
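* Assumes an unbounded set of registers (strictly, the register set is finite
because OCaml ints are bounded---the next_reg counter could eventually wrap
around and reuse register numbers, which could be a problem with
machine-generated input).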
- Better have a good register allocator!
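* Concatenating sub-programs with @ is a poor choice: @ copies its left
operand, so compiling deep (e.g. left-nested) expressions takes O(n^2) time.
Instead, build a tree-shaped representation (or accumulate instructions in
reverse), then flatten it at the end.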
* In add and mul, order among e1, e2 doesn't matter in this example, but
would if we had side effects.
* In add and mul, doesn't matter whether we choose the result reg before
or after (or between) computing the code for the sub-expressions.
*)
(* Sample input: (1 + 2) * 3 *)
let input1 =
  let sum = EAdd (EInt 1, EInt 2) in
  EMul (sum, EInt 3)

(* Sample input: (1 * 2) + 3 *)
let input2 =
  let product = EMul (EInt 1, EInt 2) in
  EAdd (product, EInt 3)
(* Check the compiler against the reference interpreter on [e]:
   - evaluate [e] directly;
   - compile it and run the generated code on a fresh register table;
   - read the result register;
   - print both values followed by "pass" or "FAIL". *)
let test e =
  let orig_res = eval_expr e in
  let (r, p) = comp_expr e in
  let regs = Hashtbl.create 7 in
  (* [let () =] rather than [let _ =]: the compiler now checks that
     run_prog returns unit instead of silently discarding any value
     (e.g. an accidental partial application). *)
  let () = run_prog regs p in
  let new_res = Hashtbl.find regs r in
  Printf.printf "Orig: %d, New: %d, %s\n" orig_res new_res
    (if orig_res = new_res then "pass" else "FAIL")