(* Theory Sorting_Guarded_Partition *)

theory Sorting_Guarded_Partition
imports Sorting_Quicksort_Scheme
begin

(* TODO: Move *)
(* Updating a list at an index i that lies outside the half-open range {l..<h}
   leaves slice l h unchanged. Declared as a simp rule. *)
lemma slice_update_outside[simp]: "i{l..<h}  slice l h (xs[i:=x]) = slice l h xs"
  unfolding Misc.slice_def
  apply auto
  by (metis drop_take leI take_update_cancel)

(* If xs and xs' are related by slice_eq_mset l h, then updating both lists at the
   same index i outside {l..<h}, with i < length xs' and the same value x, preserves
   the relation. Used in the proof of qsg_partition_correct. *)
lemma slice_eq_mset_upd_outside: "slice_eq_mset l h xs xs'  i{l..<h}  i<length xs'  slice_eq_mset l h (xs[i:=x]) (xs'[i:=x])"
  unfolding slice_eq_mset_def
  apply (auto simp: drop_update_swap not_le)
  by (metis drop_update_cancel drop_update_swap leI)



  
(* Hide the unqualified name of the circle constant Transcendental.pi.
   With (open), the qualified name Transcendental.pi remains accessible. *)
hide_const (open) Transcendental.pi ― ‹pi› is the implementation of p›, not some constant related to a circle ;)›

(* TODO: Move. Found useful for ATPs *)
(* Cotransitivity of < on a linear order: if a < c, then every b satisfies
   a < b or b < c. It is passed explicitly to metis in proofs further below. *)
lemma strict_itrans: "a < c  a < b  b < c" for a b c :: "_::linorder"
  by auto

(* Guarded partitioning scheme, using sentinels. *)  
  
  

context weak_ordering begin  
  
subsection ‹Hoare Partitioning Scheme›  


(* Specification of the unguarded upward scan. Starting at li, it yields an index
   li' with li <= li' <= si such that every element in {li..<li'} is strictly below
   p; li' is where the scan stops. The scan has no explicit bound check. The
   precondition (li <= si < length xs, plus a condition relating xs!si to p) makes
   the element at si the stopper. *)
definition "ungrd_qsg_next_l_spec si xs p li  
  doN {
    ASSERT (li  si  si<length xs  xs!si  p);
    SPEC (λli'. lili'  li'  si  (i{li..<li'}. xs!i<p)  xs!li'p)
  }"

(* Specification of the unguarded downward scan. Starting just below hi, it yields
   an index hi' with si <= hi' < hi such that every element in {hi'<..<hi} is
   strictly above p; hi' is where the scan stops. Requires si < hi <= length xs.
   The condition relating xs!si to p makes the element at si the stopper. *)
definition "ungrd_qsg_next_h_spec si xs p hi  
  doN {
    ASSERT (si < hi  hilength xs  xs!si  p);
    SPEC (λhi'. sihi'  hi'<hi  (i{hi'<..<hi}. xs!i>p)  xs!hi'p)
  }"
  
  
(* Upward scan implemented as a monadic while loop. It increments li while
   mop_cmp_idx_v xs li p holds. The loop body asserts li < si before each
   increment, i.e. the stopper at si is never passed. qsg_next_l_refine proves
   that this refines ungrd_qsg_next_l_spec. *)
definition qsg_next_l :: "nat  'a list  'a  nat  nat nres" where            
  "qsg_next_l si xs p li  doN {
    monadic_WHILEIT (λli'. lili'  (i{li..<li'}. xs!i<p)  li' si) 
      (λli. doN { mop_cmp_idx_v xs li p}) (λli. do { ASSERT (li < si); RETURN (li + 1) }) li
  }"  

  
(* qsg_next_l refines ungrd_qsg_next_l_spec, with all arguments related by Id.
   Loop termination is shown with the measure si - li. *)
lemma qsg_next_l_refine: "(qsg_next_l,PR_CONST ungrd_qsg_next_l_spec)IdIdIdIdIdnres_rel"
  unfolding qsg_next_l_def ungrd_qsg_next_l_spec_def PR_CONST_def
  apply (intro fun_relI; clarsimp)
  subgoal for si xs p li
    apply (refine_vcg monadic_WHILEIT_rule[where R="measure (λli. si - li)"] split_ifI)
    apply clarsimp_all
    subgoal by (metis le_eq_less_or_eq wo_leD)
    subgoal by (metis atLeastLessThan_iff less_Suc_eq)
    subgoal by (metis diff_less_mono2 lessI)
    subgoal using wo_not_le_imp_less by blast
    done
  done


(* Downward scan implemented as a monadic while loop. It asserts hi > 0, starts at
   hi - 1, and decrements while mop_cmp_v_idx xs p hi holds. The ASSERT (hi>0) in
   the body guards the natural-number subtraction. qsg_next_h_refine proves that
   this refines ungrd_qsg_next_h_spec. *)
definition qsg_next_h :: "nat  'a list  'a  nat  nat nres" where
  "qsg_next_h si xs p hi  doN {
    ASSERT (hi>0);
    let hi = hi - 1;
    monadic_WHILEIT (λhi'. sihi'  hi'hi  (i{hi'<..hi}. xs!i>p))
      (λhi. doN { mop_cmp_v_idx xs p hi}) (λhi. doN { ASSERT(hi>0); RETURN (hi - 1)}) hi
  }"  

  
(* qsg_next_h refines ungrd_qsg_next_h_spec, with all arguments related by Id.
   Loop termination uses the index itself as the measure (measure id). *)
lemma qsg_next_h_refine: "(qsg_next_h,PR_CONST (ungrd_qsg_next_h_spec))  Id  Id  Id  Id  Idnres_rel"
  unfolding qsg_next_h_def ungrd_qsg_next_h_spec_def PR_CONST_def
  apply (refine_vcg monadic_WHILEIT_rule[where R="measure id"] split_ifI)
  apply (all (determ elim conjE exE)?)
  apply simp_all
  subgoal by (metis bot_nat_0.extremum_uniqueI gr0I wo_leD)
  subgoal by (metis One_nat_def le_step_down_nat wo_leD)
  subgoal by (metis Suc_le_eq Suc_pred greaterThanAtMost_iff less_Suc_eq less_Suc_eq_le)
  subgoal by (simp add: greaterThanAtMost_upt)
  subgoal using unfold_le_to_lt by presburger
  done
  

(* Combined scan step. It first runs the upward scan from li with stopper index siN,
   then the downward scan from hi with stopper index si0, and returns both stop
   positions. *)
definition "ungrd_qsg_next_lh_spec si0 siN xs p li hi  doN {
  li  ungrd_qsg_next_l_spec siN xs p li;
  hi  ungrd_qsg_next_h_spec si0 xs p hi;
  RETURN (li,hi)
}"  

(* Situation at start, and after swaps *)
(* Invariant for partitioning the slice [li0, hi0) of xs0 around p using sentinels:
   - 0 < li0, li0 <= hi0 < length xs0;
   - in xs0, the element at li0 - 1 is <= p and the element at hi0 is >= p
     (the sentinels);
   - xs is a permutation of xs0 on [li0, hi0) (slice_eq_mset);
   - li0 <= li <= hi <= hi0;
   - elements of xs in {li0..<li} are <= p, elements in {hi..<hi0} are >= p. *)
definition "qsg_part_assn1 li0 hi0 p xs0 li hi xs 
    0 < li0  li0hi0  hi0<length xs0  xs0!(li0-1)  p  p  xs0!hi0
   slice_eq_mset li0 hi0 xs xs0
   li0li  lihi  hihi0
   (i{li0..<li}. xs!i  p) 
   (i{hi..<hi0}. p  xs!i) 
"

(* Invariant after a scan step; it is also the loop invariant of
   qsg_partition_aux. The sentinel and permutation conditions are the same as in
   qsg_part_assn1. Bounds: li0 <= li <= hi + 1, hi < hi0, li <= hi0.
   Elements in {li0..<li} are <= p and elements in {hi<..<hi0} are >= p.
   The last two conjuncts relate xs!li and xs!hi to p; these are the stop
   conditions of the two scans. *)
definition "qsg_part_assn2 li0 hi0 p xs0 li hi xs 
    0 < li0  li0hi0  hi0<length xs0  xs0!(li0-1)  p  p  xs0!hi0
   slice_eq_mset li0 hi0 xs xs0
   li0li  lihi+1  hi<hi0  lihi0
   (i{li0..<li}. xs!i  p) 
   (i{hi<..<hi0}. p  xs!i) 
   xs!li  p
   xs!hi  p  
"

(* Auxiliary lemma for qsg_part_12. Assume an upward scan li -> li' and a downward
   scan hi -> hi' from a state with the given sentinels and regions. Then:
   - li' <= Suc hi';
   - the <= p region extends to {li0..<li'};
   - the >= p region extends to {hi'<..<hi0}.
   The bound is proved by cases: li < li' (left scan moved), hi' < hi - 1 (right
   scan moved further), or li' = li and hi' = hi - 1. *)
lemma qsg_part_12_aux:
  assumes SENTINELS: "xs ! (li0 - Suc 0)  p" "p  xs ! hi0"
  assumes LI_LE_HI: "li  hi"
  assumes LI'_BOUND: "li  li'" "li'  hi0"
  assumes HI'_BOUND: "li0 - Suc 0  hi'" "hi' < hi" 
  assumes 
  UPTO_LI: "i{li0..<li}. xs ! i  p" and
  LI_TO_LI': "i{li..<li'}. xs ! i < p" and
  DOWNTO_HI: "i{hi..<hi0}. p  xs ! i" and
  HI'_TO_HI: "i{hi'<..<hi}. p < xs ! i"
  shows "li'  Suc hi'" "(i{li0..<li'}. xs ! i  p)" "(i{hi'<..<hi0}. p  xs ! i)"
proof -

  show G1: "(i{li0..<li'}. xs ! i  p)" using UPTO_LI LI_TO_LI' by (auto simp: unfold_lt_to_le)
  
  show G2: "(i{hi'<..<hi0}. p  xs ! i)" using HI'_TO_HI DOWNTO_HI by (auto simp: unfold_lt_to_le)
  
  consider "li<li'" | "hi' < hi-1" | "li'=li" "hi'=hi-1"
    using hi' < hi li  li' by linarith
  then show "li'  Suc hi'" proof cases
    case 1
    hence "xs!(li'-1) < p" using LI_TO_LI' by simp
    moreover have "i{hi'<..hi0}. p  xs!i" using p  xs ! hi0
      by (metis G2 greaterThanAtMost_iff greaterThanLessThan_iff le_eq_less_or_eq)
    ultimately show ?thesis
      using li'  hi0
      apply clarsimp
      by (metis Suc_leD Suc_lessE xs ! (li' - 1) < p diff_Suc_1 greaterThanAtMost_iff le_def unfold_lt_to_le)
  next
    case 2
    hence "p < xs!(hi'+1)" using HI'_TO_HI by simp
    moreover have "(i{li0-1..<li'}. xs ! i  p)" using G1 xs ! (li0 - Suc 0)  p 
      apply clarsimp
      by (metis atLeastLessThan_iff le_antisym le_refl nat_le_Suc_less_imp nat_less_le nat_neq_iff)
    ultimately show ?thesis
      using li0 - Suc 0  hi'
      apply clarsimp
      by (meson atLeastLessThan_iff le_SucI le_def unfold_lt_to_le)
  next
    case 3
    then show ?thesis using lihi by linarith
  qed
qed  

(* Start from a qsg_part_assn1 state and run one combined scan step, with stoppers
   at li0 - 1 and hi0. The step reaches a qsg_part_assn2 state and strictly
   decreases hi; the loop's termination measure in qsg_partition_aux_correct
   relies on this decrease. *)
lemma qsg_part_12: "qsg_part_assn1 li0 hi0 p xs0 li hi xs 
   ungrd_qsg_next_lh_spec (li0-1) hi0 xs p li hi  SPEC (λ(li',hi').
    qsg_part_assn2 li0 hi0 p xs0 li' hi' xs  hi'<hi
     )"
  unfolding ungrd_qsg_next_lh_spec_def ungrd_qsg_next_l_spec_def ungrd_qsg_next_h_spec_def
  apply refine_vcg
  unfolding qsg_part_assn1_def
  apply (clarsimp_all simp: slice_eq_mset_eq_length)
  subgoal by (metis hd_drop_conv_nth slice_eq_mset_def)
  subgoal by linarith
  subgoal for i
    apply (subst slice_eq_mset_nth_outside, assumption)
    apply auto
    by (metis diff_diff_cancel less_imp_diff_less slice_eq_mset_eq_length)
  subgoal for hi' li'  
    unfolding qsg_part_assn2_def
    apply (clarsimp intro!: qsg_part_12_aux)
    apply (blast dest: qsg_part_12_aux)
    done
  done  

  
  
(* Hoare-style partitioning of the slice [li0, hi0) of xs0 around p. It relies on
   the sentinels at li0 - 1 and hi0 described in qsg_part_assn1; li0 > 0 is
   asserted. The algorithm:
   - run an initial combined scan;
   - while li < hi, swap the elements at li and hi where the scans stopped, step li
     past the swapped element, and rescan.
   Returns the permuted list and the final li, which is the split point.
   Loop invariant: qsg_part_assn2. Correctness: qsg_partition_aux_correct. *)
definition "qsg_partition_aux li0 hi0 p xs0  doN {
  ― ‹Initialize›
  ASSERT (li0>0);
  (li,hi)  ungrd_qsg_next_lh_spec (li0-1) hi0 xs0 p li0 hi0;
  
  (xs,li,hi)  WHILEIT 
    (λ(xs,li,hi). qsg_part_assn2 li0 hi0 p xs0 li hi xs)
    (λ(xs,li,hi). li<hi) 
    (λ(xs,li,hi). doN {
      ASSERT(li<hi  li<length xs  hi<length xs  lihi);
      xs  mop_list_swap xs li hi;
      let li = li + 1;
      
      (li,hi)  ungrd_qsg_next_lh_spec (li0-1) hi0 xs p li hi;
      RETURN (xs,li,hi)
    }) 
    (xs0,li,hi);
  
  RETURN (xs,li)
}"  


(* Postcondition for partitioning the slice [li, hi) of xs around p:
   - xs' is a permutation of xs on that slice;
   - the split point m lies in {li..hi};
   - elements of xs' in {li..<m} are <= p;
   - elements of xs' in {m..<hi} are >= p. *)
definition "gpartition_spec li hi p xs xs' m  
    slice_eq_mset li hi xs' xs 
   m{li..hi}
   (i{li..<m}. xs'!i  p)
   (i{m..<hi}. p  xs'!i)"


(* Abstract partitioning operation. It asserts li <= hi <= length xs and returns
   any pair (xs', m) satisfying gpartition_spec. *)
definition "gpartition_SPEC li hi p xs  do {
  ASSERT (lihi  hilength xs);
  SPEC (λ(xs',m). gpartition_spec li hi p xs xs' m)
}"

(* Initial state: under the sentinel preconditions, qsg_part_assn1 holds when li,
   hi and xs serve as both the initial and the current values. *)
lemma qsg_part1_init: "0 < li; hi < length xs; xs ! (li - Suc 0)  p; p  xs ! hi; li  hi  qsg_part_assn1 li hi p xs li hi xs"
  unfolding qsg_part_assn1_def
  by simp

(* Under qsg_part_assn2, both scan positions li and hi are valid indices into xs. *)
lemma qsg_part2_in_bounds: 
  assumes "qsg_part_assn2 li0 hi0 p xs0 li hi xs" 
  shows "li<length xs" "hi<length xs"
  using assms unfolding qsg_part_assn2_def
  by (auto dest: slice_eq_mset_eq_length)
  
  
(* Loop step: from qsg_part_assn2 with li < hi, swapping positions li and hi and
   advancing li by one re-establishes qsg_part_assn1. *)
lemma qsg_part_21: "qsg_part_assn2 li0 hi0 p xs0 li hi xs; li < hi  qsg_part_assn1 li0 hi0 p xs0 (Suc li) hi (swap xs li hi)"
  unfolding qsg_part_assn2_def qsg_part_assn1_def
  apply (intro conjI)
  apply clarsimp_all
  subgoal by (metis atLeastLessThan_iff diff_diff_cancel less_Suc_eq_le less_imp_diff_less nat_less_le slice_eq_mset_eq_length swap_indep swap_nth1)
  subgoal by (metis Suc_diff_Suc diff_is_0_eq greaterThanLessThan_iff nat.simps(3) not_less_iff_gr_or_eq slice_eq_mset_eq_length strict_itrans swap_indep swap_nth2)
  done
  
(* Loop exit: qsg_part_assn2 together with not (li < hi) implies gpartition_spec
   on [li0, hi0), with split point li. *)
lemma qsg_part_2fin: "qsg_part_assn2 li0 hi0 p xs0 li hi xs; ¬ li < hi  gpartition_spec li0 hi0 p xs0 xs li"  
  unfolding qsg_part_assn2_def gpartition_spec_def
  apply clarsimp
  by (metis atLeastLessThan_iff atLeastSucLessThan_greaterThanLessThan le_antisym linorder_le_less_linear nat_Suc_less_le_imp)
  
(* Correctness of qsg_partition_aux. Given 0 < li, hi < length xs, and the
   sentinels at li - 1 and hi, it refines gpartition_SPEC on the slice [li, hi).
   Loop termination uses the measure hi. *)
lemma qsg_partition_aux_correct:
  "0<li; hi<length xs; xs!(li-1)  p; p  xs!hi  qsg_partition_aux li hi p xs  gpartition_SPEC li hi p xs"  
  unfolding qsg_partition_aux_def gpartition_SPEC_def
  apply (refine_vcg WHILEIT_rule[where R="measure (λ(_,_,hi). hi)"] qsg_part_12[where xs0=xs])  
  apply clarsimp_all
  subgoal by (simp add: qsg_part1_init)
  subgoal by (simp add: qsg_part2_in_bounds)
  subgoal by (simp add: qsg_part2_in_bounds)
  subgoal by (simp add: qsg_part_21)
  subgoal by (simp add: qsg_part_2fin)
  done    

(* Partition the slice [li0, hi0) of xs0 around p without requiring sentinels from
   the caller. Requires li0 + 1 < hi0 <= length xs0. Steps:
   - a copy of p is placed at positions li0 and hi0 - 1; the elements previously
     there are kept as e0 and eN. This assumes mop_idx_v_swap, defined elsewhere,
     returns the old element together with the updated list; confirm.
   - Those two positions then serve as sentinels for qsg_partition_aux on the inner
     slice [li0 + 1, hi0 - 1).
   - e0 and eN are written back.
   - Depending on e0ok (e0 <= p) and eNok (p <= eN), at most one final swap puts
     them on the correct side of the split point m, adjusting m by one where
     needed. *)
definition "qsg_partition li0 hi0 p xs0  do {
  ASSERT (li0+1<hi0  hi0length xs0);

  (e0,xs)  mop_idx_v_swap xs0 li0 (COPY p);
  let li=li0+1;
  
  let hi=hi0-1;
  (eN,xs)  mop_idx_v_swap xs hi (COPY p);

  (xs,m)  qsg_partition_aux li hi p xs;
  
  let li = li-1;

  let e0ok = e0  p;
  let eNok = p  eN;
  (_,xs)  mop_idx_v_swap xs li e0;
  (_,xs)  mop_idx_v_swap xs hi eN;
  
  ASSERT (slice_eq_mset li (hi+1) xs xs0);
  
  if e0ok  eNok then
    RETURN (xs,m)
  else if e0ok  ¬eNok then do {
    xs  mop_list_swap xs m hi;
    ASSERT (m<hi0);
    let m=m+1;
    RETURN (xs,m)
  } else if eNok then do {
    ASSERT (1m);
    let m=m-1;
    xs  mop_list_swap xs li m;
    RETURN (xs,m)
  } else do {
    xs  mop_list_swap xs li hi;
    RETURN (xs,m)
  }
  
}"


(* Partitioning for arbitrary ranges [li, hi):
   - at least two elements (li + 1 < hi): delegate to qsg_partition;
   - exactly one element: if mop_cmp_idx_v xs li p holds, the split point is hi,
     so xs!li is in the lower part; otherwise it is li;
   - empty range (li >= hi): nothing to do, the split point is li.
   Its refinement against gpartition_SPEC is qsg_partition_wrapper_refine. *)
definition "qsg_partition_wrapper li hi p xs  do {
  if li<hi then (
    if li+1<hi then qsg_partition li hi p xs
    else doN {
      ifN mop_cmp_idx_v xs li p then RETURN (xs,hi)
      else RETURN (xs,li)
    }
  ) else RETURN (xs,li)
}"


(* For a slice [li, hi) with at least two elements (li + 1 < hi), qsg_partition
   refines gpartition_SPEC. The proof goes through qsg_partition_aux_correct and
   then reasons about restoring e0/eN. The first explicit subgoal discharges the
   slice_eq_mset assertion using slice_eq_mset_upd_outside. The remaining subgoals
   presumably correspond to the swap-back cases of qsg_partition (TODO: confirm
   the mapping). *)
lemma qsg_partition_correct: "li+1<hi  qsg_partition li hi p xs  gpartition_SPEC li hi p xs"
  unfolding qsg_partition_def gpartition_SPEC_def
  apply (refine_vcg qsg_partition_aux_correct[THEN order_trans])
  apply simp_all
  apply clarsimp_all
  apply simp_all
  unfolding gpartition_SPEC_def
  apply refine_vcg
  apply clarsimp_all
  apply simp_all
  unfolding gpartition_spec_def
  apply clarsimp_all
  apply (simp_all add: slice_eq_mset_eq_length)
  subgoal for xs' m
    apply (drule slice_eq_mset_upd_outside[where i="li" and x="xs ! li"]; simp?)
    apply (drule slice_eq_mset_upd_outside[where i="hi - Suc 0" and x="xs ! (hi - Suc 0)"]; simp?)
    apply (simp add: list_update_swap)
    apply (erule slice_eq_mset_subslice)
    apply auto
    done
  subgoal for xs' m
    apply (intro conjI)
    subgoal by simp
    subgoal by (auto simp: nth_list_update' slice_eq_mset_eq_length)
    subgoal by (clarsimp simp: nth_list_update' slice_eq_mset_eq_length)
    done
  subgoal for xs' m
    apply (intro conjI)
    subgoal
      apply (simp (no_asm_simp) add: slice_eq_mset_eq_length swap_nth nth_list_update) 
      using connex
      by blast
    subgoal by (fastforce simp: slice_eq_mset_eq_length swap_nth nth_list_update)
    done
  subgoal for xs' m
    apply (intro conjI)
    subgoal by (metis Suc_le_D Suc_to_right atLeastLessThan_iff le_Suc_eq le_eq_less_or_eq nz_le_conv_less slice_eq_mset_swap(1) zero_less_Suc)
    subgoal by simp 
    subgoal by simp 
    subgoal by (fastforce simp add: slice_eq_mset_eq_length swap_nth nth_list_update) 
    subgoal
      apply (simp (no_asm_simp) add: slice_eq_mset_eq_length swap_nth nth_list_update)
      apply safe
      subgoal using connex apply blast done
      subgoal by (metis atLeastLessThan_iff le_def le_eq_less_or_eq nat_le_Suc_less_imp)
      subgoal using connex by blast
      subgoal using connex by blast
      subgoal by (metis Suc_le_D Suc_le_lessD Suc_to_right atLeastLessThan_iff nat_in_between_eq(2))
      by (meson atLeastLessThan_iff le_antisym linorder_not_le nat_le_Suc_less_imp)
    done  
  subgoal for xs' m
    apply (intro conjI)
    subgoal by simp 
    subgoal
      apply (simp (no_asm_simp) add: slice_eq_mset_eq_length swap_nth nth_list_update) 
      using connex
      by blast
    subgoal
      apply (simp (no_asm_simp) add: slice_eq_mset_eq_length swap_nth nth_list_update)
      by (meson atLeastLessThan_iff connex diff_is_0_eq' diffs0_imp_equal le_def nat_le_Suc_less_imp)
    done
  done

  
  (* Register gpartition_SPEC as an abstract operation for Sepref. *)
  sepref_register gpartition_SPEC
  
  (* qsg_partition_wrapper refines gpartition_SPEC for all ranges. The case
     li + 1 < hi follows from qsg_partition_correct; the one-element and empty
     cases are discharged directly against gpartition_spec. *)
  lemma qsg_partition_wrapper_refine: "(qsg_partition_wrapper, PR_CONST gpartition_SPEC)  Id  Id  Id  Id  Idnres_rel"
    unfolding qsg_partition_wrapper_def
    apply (clarsimp split!: if_split intro!: nres_relI)
    subgoal by (simp add: qsg_partition_correct)
    unfolding gpartition_SPEC_def
    apply (all refine_vcg)
    apply simp_all
    unfolding gpartition_spec_def
    apply auto
    subgoal by (metis connex le_less_Suc_eq strict_itrans wo_leD)
    subgoal by (metis le_less_Suc_eq strict_itrans wo_leI)
    done
  
  
end  
  

context sort_impl_context begin

  
(* Register the two scan specifications with Sepref, so the composed refinement
   rules declared below can implement them. *)
sepref_register ungrd_qsg_next_l_spec ungrd_qsg_next_h_spec 

(* TODO: We can get rid of the length xs restriction: the stopper element will always lie within <h, which is size_t representable! *)
(* Sepref implementation of qsg_next_l with size_t indices. The array and pivot are
   keep (^k) arguments. The definition is marked llvm_inline. *)
sepref_definition qsg_next_l_impl [llvm_inline] is "uncurry3 (qsg_next_l)" :: "size_assnk *a (arr_assn)k *a elem_assnk *a size_assnk a size_assn"
  unfolding qsg_next_l_def PR_CONST_def
  apply (annot_snat_const "TYPE(size_t)")
  by sepref

(* Compose with qsg_next_l_refine, so Sepref implements ungrd_qsg_next_l_spec
   directly by qsg_next_l_impl. *)
lemmas [sepref_fr_rules] = qsg_next_l_impl.refine[FCOMP qsg_next_l_refine]  
  
(* Sepref implementation of qsg_next_h with size_t indices. The array and pivot are
   keep (^k) arguments. The definition is marked llvm_inline. *)
sepref_definition qsg_next_h_impl [llvm_inline] is "uncurry3 (qsg_next_h)" :: "size_assnk *a (arr_assn)k *a elem_assnk *a size_assnk a size_assn"
  unfolding qsg_next_h_def PR_CONST_def
  apply (annot_snat_const "TYPE(size_t)")
  by sepref
  
(* Compose with qsg_next_h_refine, so Sepref implements ungrd_qsg_next_h_spec
   directly by qsg_next_h_impl. *)
lemmas [sepref_fr_rules] = qsg_next_h_impl.refine[FCOMP qsg_next_h_refine]  
  

(* Sepref implementation of qsg_partition_aux:
   - the array is a destroyed (^d) argument, and the returned array is the same
     pointer as the input (postcondition r = ai);
   - ungrd_qsg_next_lh_spec is unfolded, so the two scan specs are implemented
     through the sepref_fr_rules above;
   - the nested binds are first flattened with nres_monad_laws. *)
sepref_register qsg_partition_aux  
sepref_def qsg_partition_aux_impl (*[llvm_inline]*) is "uncurry3 (PR_CONST qsg_partition_aux)" 
  :: "[λ_. True]c size_assnk *a size_assnk *a elem_assnk *a (arr_assn)d 
     arr_assn ×a size_assn [λ(((_,_),_),ai) (r,_). r=ai]c"
  unfolding qsg_partition_aux_def PR_CONST_def ungrd_qsg_next_lh_spec_def
  apply (simp only: nres_monad_laws split)
  apply (annot_snat_const "TYPE(size_t)")
  by sepref

(* TODO: Move *)                  
(* Rewrite a let-bound comparison a <= b into not (b < a). It is unfolded when
   synthesizing qsg_partition_impl. *)
lemma unfold_let_le: "(let x = ab in f x) = (let x = ¬ b<a in f x)"  
  by (simp add: unfold_le_to_lt)
        
(* TODO: Move *)                  
  
(* Swap positions i and j only if the indices differ; otherwise return xs
   unchanged. list_guarded_swap_refine shows this refines mop_list_swap. *)
definition "list_guarded_swap xs i j  if ij then mop_list_swap xs i j else RETURN xs "  
  
(* list_guarded_swap refines mop_list_swap for related arguments. It is declared
   [refine], so refine_rcg can replace mop_list_swap by the guarded version. *)
lemma list_guarded_swap_refine[refine]: 
  " (xs,xs')Idlist_rel; (i,i')Id; (j,j')Id   list_guarded_swap xs i j Id (mop_list_swap xs' i' j')"
  unfolding list_guarded_swap_def
  apply simp
  apply refine_vcg
  by simp

(* Register list_guarded_swap as an abstract operation for Sepref. *)
sepref_register list_guarded_swap  
  
(* Sepref implementation of list_guarded_swap. The array is a destroyed (^d)
   argument, and the returned array is the same pointer (postcondition r = p). *)
sepref_def list_guarded_swap_impl [llvm_inline] is "uncurry2 list_guarded_swap" 
  :: "[λ_. True]c arr_assnd *a size_assnk *a size_assnk  arr_assn [λ((p,_),_) r. r=p]c"
  unfolding list_guarded_swap_def
  by sepref

end

context sort_impl_copy_context begin
  
(* Some refinement to tame exploding sepref *)  
(* Final fix-up of qsg_partition, factored into its own definition. It covers the
   same four cases on e0ok/eNok as qsg_partition, but uses list_guarded_swap
   instead of mop_list_swap; by list_guarded_swap_refine the two are equivalent.
   Returns the list and the possibly adjusted split point m. *)
definition "qsg_partition_swap_back hi0 li hi m xs e0ok eNok  do {
  if e0ok  eNok then
    RETURN (xs,m)
  else if e0ok  ¬eNok then do {
    xs  list_guarded_swap xs m hi;
    ASSERT (m<hi0);
    let m=m+1;
    RETURN (xs,m)
  } else if eNok then do {
    ASSERT (1m);
    let m=m-1;
    xs  list_guarded_swap xs li m;
    RETURN (xs,m)
  } else do {
    xs  list_guarded_swap xs li hi;
    RETURN (xs,m)
  }

}"
  
(* Variant of qsg_partition prepared for Sepref:
   - the upfront bounds assertion is replaced by local ASSERTs (li < hi0, hi0 > 0,
     li > 0) that guard the index arithmetic;
   - the slice_eq_mset assertion is dropped;
   - the fix-up is delegated to qsg_partition_swap_back.
   qsg_partition2_refine shows it refines qsg_partition. *)
definition "qsg_partition2 li hi0 p xs0  do {
  (e0,xs)  mop_idx_v_swap xs0 li (COPY p);
  ASSERT (li<hi0);
  let li=li+1;
  
  ASSERT (hi0>0);
  let hi=hi0-1;
  (eN,xs)  mop_idx_v_swap xs hi (COPY p);

  (xs,m)  qsg_partition_aux li hi p xs;
  
  ASSERT (li>0);
  let li = li-1;

  let e0ok = e0  p;
  let eNok = p  eN;
  (_,xs)  mop_idx_v_swap xs li e0;
  (_,xs)  mop_idx_v_swap xs hi eN;

  qsg_partition_swap_back hi0 li hi m xs e0ok eNok
}"

(* qsg_partition2 refines qsg_partition, with all arguments related by Id. This
   transfers qsg_partition_correct to the Sepref-friendly variant. *)
lemma qsg_partition2_refine: "(PR_CONST  qsg_partition2, PR_CONST qsg_partition)  Id  Id  Id  Id  Idnres_rel"
  unfolding qsg_partition2_def qsg_partition_def qsg_partition_swap_back_def PR_CONST_def
  apply refine_rcg
  apply refine_dref_type
  apply simp_all
  done

(* Register qsg_partition_swap_back as an abstract operation for Sepref. *)
sepref_register qsg_partition_swap_back

(* Sepref implementation of qsg_partition_swap_back. The array is a destroyed (^d)
   argument, and the returned array is the same pointer as the input
   (postcondition r = ai). Uses the plain sepref method rather than the
   debugging variant sepref_dbg_keep. *)
sepref_def qsg_partition_swap_back_impl is "uncurry6 (qsg_partition_swap_back)" 
  :: "[λ_. True]c size_assnk *a size_assnk *a size_assnk *a size_assnk *a (arr_assn)d *a bool1_assnk *a bool1_assnk
     arr_assn ×a size_assn [λ((((((_,_),_),_),ai),_),_) (r,_). r=ai]c"
  unfolding qsg_partition_swap_back_def PR_CONST_def
  apply (annot_snat_const "TYPE(size_t)")
  by sepref
  
  

(* Register qsg_partition. It is implemented through qsg_partition2, via the
   composed rule qsg_partition_impl'_hnr below. *)
sepref_register qsg_partition

(* Sepref implementation of qsg_partition2, marked llvm_code. The array is a
   destroyed (^d) argument, and the returned array is the same pointer as the
   input (postcondition r = ai). unfold_let_le rewrites the let-bound <=
   comparisons before synthesis. Uses the plain sepref method rather than the
   debugging variant sepref_dbg_keep. *)
sepref_definition qsg_partition_impl [llvm_code] (*[llvm_inline]*) is "uncurry3 (PR_CONST qsg_partition2)" 
  :: "[λ_. True]c size_assnk *a size_assnk *a elem_assnk *a (arr_assn)d 
     arr_assn ×a size_assn [λ(((_,_),_),ai) (r,_). r=ai]c"
  unfolding qsg_partition2_def PR_CONST_def unfold_let_le
  supply [[goals_limit = 1]]
  apply (annot_snat_const "TYPE(size_t)")
  by sepref

(* Compose with qsg_partition2_refine, so Sepref implements qsg_partition by
   qsg_partition_impl. *)
lemmas qsg_partition_impl'_hnr[sepref_fr_rules] = qsg_partition_impl.refine[FCOMP qsg_partition2_refine]


(* Sepref implementation of qsg_partition_wrapper, marked llvm_code. The array is a
   destroyed (^d) argument, and the returned array is the same pointer as the
   input (postcondition r = ai). *)
sepref_register qsg_partition_wrapper
sepref_definition qsg_partition_wrapper_impl [llvm_code] is "uncurry3 (qsg_partition_wrapper)"
  :: "[λ_. True]c size_assnk *a size_assnk *a elem_assnk *a (arr_assn)d 
     arr_assn ×a size_assn [λ(((_,_),_),ai) (r,_). r=ai]c"
  unfolding qsg_partition_wrapper_def PR_CONST_def
  apply (annot_snat_const "TYPE(size_t)")
  by sepref


(* Compose with qsg_partition_wrapper_refine, so Sepref implements gpartition_SPEC
   by qsg_partition_wrapper_impl. *)
lemmas qsg_partition_wrapper_impl'_hnr[sepref_fr_rules] = qsg_partition_wrapper_impl.refine[FCOMP qsg_partition_wrapper_refine]  
  

end

end