(* Find a sequence of PHYSICS moves that takes the registers from one
   state to another. *)

module F = struct

  (* A state is the tuple of the six registers (s0, s1, s2, s3, d0, d1). *)
  type state = int * int * int * int * int * int

  (* Goal test.  Only s0 is inspected: a state is accepted as soon as
     s0 equals 1.

     A stricter goal, currently disabled:
       s0=255 && (s1=2 || s2=2 || s3=2) &&
       (d0=255 && 10<=d1 && d1<=254 || d1=255 && 10<=d0 && d0<=254)
  *)
  let success (s0, _, _, _, _, _) = s0 = 1

  (* A move is identified by its operand, an integer in 0..31. *)
  type move = int

  (* [move n st] applies operand [n] to state [st] and returns the new state.

     1. [n] is read as a 5-bit signed number (0..15 stay as they are,
        16..31 stand for -16..-1) and added to s0, wrapping once around
        256.
     2. The updated s0, followed by every register whose bit is set in [n]
        (16 -> s1, 8 -> s2, 4 -> s3, 2 -> d0, 1 -> d1), forms a cycle.
     3. The cycle is rotated by one position: each participating register
        receives the value of the next participating one, and the last
        receives the updated s0.  Registers whose bit is clear are left
        untouched.

     Both [n] and the wrapped s0 are checked with [assert]. *)
  let move n (s0, s1, s2, s3, d0, d1) =
    assert (0 <= n && n <= 31);
    let selected bit = n land bit <> 0 in
    (* bit 16 doubles as the sign bit of the 5-bit operand *)
    let delta = if selected 16 then n - 32 else n in
    let sum = s0 + delta in
    let s0 =
      if sum < 0 then sum + 256
      else if sum > 255 then sum - 256
      else sum
    in
    assert (0 <= s0 && s0 <= 255);
    (* values taking part in the rotation, in register order *)
    let pick bit reg rest = if selected bit then reg :: rest else rest in
    let cycle =
      s0 :: pick 16 s1 (pick 8 s2 (pick 4 s3 (pick 2 d0 (pick 1 d1 []))))
    in
    (* rotate left by one: the head moves to the end *)
    let rotated =
      match cycle with
      | first :: others -> others @ [first]
      | [] -> assert false
    in
    (* hand the rotated values back to the registers that contributed one *)
    let take bit old queue =
      if selected bit then
        match queue with
        | v :: rest -> v, rest
        | [] -> assert false
      else old, queue
    in
    let s0, queue =
      match rotated with
      | v :: rest -> v, rest
      | [] -> assert false
    in
    let s1, queue = take 16 s1 queue in
    let s2, queue = take 8 s2 queue in
    let s3, queue = take 4 s3 queue in
    let d0, queue = take 2 d0 queue in
    let d1, queue = take 1 d1 queue in
    assert (queue = []);
    (s0, s1, s2, s3, d0, d1)

  (* [moves st] lists the successors of [st] as (operand, new state) pairs,
     one per candidate operand and in the order of the candidate list.

     Operands 5 to 28 are currently left out of the candidates:
       5;6;7;8; 9;10;11;12; 13;14;15;16;
       17;18;19;20; 21;22;23;24; 25;26;27;28;
     (21 is nevertheless tried, as the last candidate). *)
  let moves st =
    let candidates = [1; 2; 3; 4; 29; 30; 31; 21] in
    let apply n = (n, move n st) in
    List.map apply candidates

  (* Visited states are put in a hash table using function [mark]; here a
     state is its own key and every state is marked. *)
  type marked_state = state
  let mark st = Some st

end
open F
(* Instantiate the search functor from the external Search module on F;
   S.search is what the entry point below calls.
   NOTE(review): IDS presumably stands for iterative-deepening search --
   confirm against the Search module. *)
module S = Search.FunctionalIDS(F)

open Format

(* [s5 n] reads [n] as a 5-bit signed number: values below 16 are returned
   unchanged, anything else becomes its low four bits minus 16 (so 16..31
   map to -16..-1). *)
let s5 n =
  if n < 16 then n
  else (n land 15) - 16

(* Entry point: run the search from a fixed start state, then print the
   number of moves found, one PHYSICS line per move (operand shown in its
   signed form), and finally the registers of the state that was reached. *)
let () =
  let start = (255, 4, 239, 3, 255, 253) in
  let final, path = S.search start in
  printf "YES! (%d)@." (List.length path);
  let print_step n = printf "PHYSICS %d;@." (s5 n) in
  List.iter print_step path;
  let s0, s1, s2, s3, d0, d1 = final in
  printf "%d,%d,%d,%d / %d,%d@." s0 s1 s2 s3 d0 d1