(* Theory Mpat_Antiquot *)
section ‹Matching-Pattern Antiquotation›
theory Mpat_Antiquot
imports Refine_Util_Bootstrap1
begin
(* Uninterpreted auxiliary types.  They carry no logical content in this
   theory; they only give types to the mpaq_* constants declared below, which
   describe the structure of terms, types and sorts inside a pattern. *)
typedecl term_struct
typedecl typ_struct
typedecl sort
(* "a AS⇩p s" is logically just "a".  The ML code of this theory gives it a
   special reading: the left-hand side must be a schematic variable, and the
   generated ML pattern is an as-pattern binding that variable to whatever
   matches the right-hand side. *)
definition mpaq_AS_PAT :: "'a ⇒ 'a ⇒ 'a" (infixl "AS⇩p" 900)
  where "a AS⇩p s ≡ a"
(* Marker that embeds a term_struct description at an arbitrary type.  Its
   value is undefined; the ML code of this theory hands the argument to the
   structural printer instead of treating it as an ordinary term. *)
definition mpaq_STRUCT :: "term_struct ⇒ 'a" where "mpaq_STRUCT _ = undefined"
(* Shorthand combining the two constants above: the left-hand side is bound
   to a term whose shape is given by the term_struct on the right. *)
abbreviation mpaq_AS_STRUCT :: "'a ⇒ term_struct ⇒ 'a" (infixl "AS⇩s" 900)
  where "a AS⇩s s ≡ a AS⇩p mpaq_STRUCT s"
(* Unspecified constants describing term structure.  The ML code of this
   theory prints them as the ML term constructors Const, Free, Var, Bound,
   Abs and the application operator, respectively. *)
consts
  mpaq_Const :: "string ⇒ typ_struct ⇒ term_struct"
  mpaq_Free :: "string ⇒ typ_struct ⇒ term_struct"
  mpaq_Var :: "string*nat ⇒ typ_struct ⇒ term_struct"
  mpaq_Bound :: "nat ⇒ term_struct"
  mpaq_Abs :: "string ⇒ typ_struct ⇒ term_struct ⇒ term_struct"
  mpaq_App :: "term_struct ⇒ term_struct ⇒ term_struct" (infixl "$$" 900)
(* Unspecified constants describing type structure.  The ML code of this
   theory prints them as the ML type constructors TFree, TVar and Type. *)
consts
  mpaq_TFree :: "string ⇒ sort ⇒ typ_struct"
  mpaq_TVar :: "string*nat ⇒ sort ⇒ typ_struct"
  mpaq_Type :: "string ⇒ typ_struct list ⇒ typ_struct"
ML ‹
  
  local
    (* Replace every dummy-pattern constant in a term by a fresh schematic
       variable named "_dummy_", numbered from the counter idx upwards.
       Returns the rewritten term together with the next unused counter. *)
    fun replace_dummy' tm idx =
      (case tm of
        Const (@{const_name Pure.dummy_pattern}, ty) =>
          (Var (("_dummy_", idx), ty), idx + 1)
      | Abs (v, ty, body) =>
          let val (body', idx') = replace_dummy' body idx
          in (Abs (v, ty, body'), idx') end
      | f $ arg =>
          let
            val (f', idx1) = replace_dummy' f idx;
            val (arg', idx2) = replace_dummy' arg idx1;
          in (f' $ arg', idx2) end
      | atom => (atom, idx));
    (* Apply replace_dummy' to a list of terms, threading one counter that
       starts at 1 through all of them; the final counter is discarded. *)
    fun prepare_dummies' ts = fst (fold_map replace_dummy' ts 1);
    (* Validate and normalise a type occurring in a pattern.
       seen accumulates the names of the 'v_-prefixed schematic type variables
       encountered so far.
       - Free type variables are rejected with an error.
       - The first occurrence of a schematic type variable whose name starts
         with 'v_ loses that prefix and its original name is recorded in seen.
       - Every other schematic type variable, including repeated occurrences
         of a recorded one, is renamed to "_" with index 0. *)
    fun tcheck ty seen =
      (case ty of
        Type (c, args) =>
          let val (args', seen') = map_fold tcheck args seen
          in (Type (c, args'), seen') end
      | TFree (a, _) =>
          error ("Pattern contains free type variable: " ^ a)
      | TVar ((a, i), S) =>
          if String.isPrefix "'v_" a andalso not (member (op =) seen a)
          then (TVar ((unprefix "'v_" a, i), S), a :: seen)
          else (TVar (("_", 0), S), seen))
    (* Validate and normalise a pattern term.
       The state is a pair (vnames, tnames): names of schematic variables seen
       so far, and the type-variable names maintained by tcheck.
       - Types of constants, variables and abstractions go through tcheck.
       - Free variables are rejected with an error.
       - Variables named "_dummy_" become the variable "_" with index 0.
       - Other schematic variables must have index 0 and must not repeat
         (the pattern has to be linear); otherwise an error is raised.
       - An abstraction whose bound name is "uu_" gets the bound name "_". *)
    fun vcheck tm (seen as (vnames, tnames)) =
      (case tm of
        Const (c, ty) =>
          let val (ty', tnames') = tcheck ty tnames
          in (Const (c, ty'), (vnames, tnames')) end
      | Free (x, _) =>
          error ("Pattern contains free variable: " ^ x)
      | Var (("_dummy_", _), ty) =>
          let val (ty', tnames') = tcheck ty tnames
          in (Var (("_", 0), ty'), (vnames, tnames')) end
      | Var ((x, idx), ty) =>
          let
            (* The type is checked first, so a type error wins over the
               index and linearity errors below. *)
            val (ty', tnames') = tcheck ty tnames
            val _ =
              if idx <> 0
              then error ("Variable index greater zero: "
                ^ Term.string_of_vname (x, idx))
              else ()
            val _ =
              if member (op =) vnames x
              then error ("Non-linear pattern: "
                ^ Term.string_of_vname (x, idx))
              else ()
          in (Var ((x, 0), ty'), (x :: vnames, tnames')) end
      | Bound _ => (tm, seen)
      | Abs (x, ty, body) =>
          let
            val (ty', tnames') = tcheck ty tnames
            val (body', seen') = vcheck body (vnames, tnames')
            val x' = if x = "uu_" then "_" else x
          in (Abs (x', ty', body'), seen') end
      | f $ arg =>
          let
            val (f', seen1) = vcheck f seen
            val (arg', seen2) = vcheck arg seen1
          in (f' $ arg', seen2) end)
    (* Argument parser of the antiquotation.  It reads the mode flag "typs"
       (written "(typs)" in the examples of this theory), then pairs the
       current context with the pattern, given as embedded inner syntax
       together with its source position.
       Result shape: (with_types, (ctxt, (raw_t, pos))), as consumed by process. *)
    val read = Scan.lift (Args.mode "typs" )
      -- (Args.context -- Scan.lift (Parse.position Parse.embedded_inner_syntax))
        
    (* Turn a parsed pattern term into the source text of an ML pattern.

       with_types: when true, types are rendered structurally by twr_aux;
         when false every type position becomes the ML wildcard "_".
       t: the pattern term; it is validated and normalised by vcheck first.

       Schematic variables are written as bare ML identifiers of the same
       name, so the generated pattern binds them.  The result is wrapped in
       parentheses.

       Raises TERM for malformed structure arguments and when the left-hand
       side of mpaq_AS_PAT is not a schematic variable.  Errors raised by
       vcheck propagate unchanged. *)
    fun write (with_types, t) = let
      (* ML pattern text for a type.  Schematic type variables are written as
         bare identifiers; the sort of a free type variable is a wildcard. *)
      fun twr_aux (Type (name,Ts)) 
            = "Type (" ^ ML_Syntax.print_string name ^ ", " ^ ML_Syntax.print_list twr_aux Ts ^ ")"
        | twr_aux (TFree (name,_)) = "TFree (" ^ ML_Syntax.print_string name ^ ", _)"
        | twr_aux (TVar ((name,_),_)) = name
  
      (* Type printer actually used: structural only in "typs" mode. *)
      val twr = if with_types then twr_aux else K "_"

      (* Printers for the arguments of the mpaq_* structure constants.  Each
         accepts either a schematic variable, which becomes an ML identifier
         bound by the pattern, or a literal of the expected HOL type. *)
      fun s_string (Var ((name,_),_)) = name
        | s_string t = case try HOLogic.dest_string t of
            SOME s => ML_Syntax.print_string s
          | NONE => raise TERM ("Expected var or string literal",[t])
      fun s_index (Var ((name,_),_)) = name
        | s_index t = case try HOLogic.dest_nat t of
            SOME n => ML_Syntax.print_int n
          | NONE => raise TERM ("Expected var or nat literal",[t])
      fun s_indexname (Var ((name,_),_)) = name
        | s_indexname (Const (@{const_name Pair},_)$n$i) =
            "(" ^ s_string n ^ ", " ^ s_index i ^ ")"
        | s_indexname t = raise TERM ("Expected var or indexname-pair",[t])

      (* Printer for typ_struct descriptions and their argument lists. *)
      fun s_typ (Var ((name,_),_)) = name
        | s_typ (Const(@{const_name "mpaq_TFree"},_)$name$typ) 
            = "TFree (" ^ s_string name ^ "," ^ s_typ typ ^ ")"
        | s_typ (Const(@{const_name "mpaq_TVar"},_)$name$typ)
            = "TVar (" ^ s_indexname name ^ "," ^ s_typ typ ^ ")"
        | s_typ (Const(@{const_name "mpaq_Type"},_)$name$args)
            = "Type (" ^ s_string name ^ "," ^ s_args args ^ ")"
        | s_typ t = raise TERM ("Expected var or type structure",[t])
      and s_args (Var ((name,_),_)) = name
        | s_args t = ( case try HOLogic.dest_list t of
            SOME l => ML_Syntax.print_list s_typ l
          | NONE => raise TERM ("Expected variable or type argument list",[t])
          )

      (* Printer for term_struct descriptions.  A schematic variable in
         term-structure position binds the whole ML term found there, in line
         with the Var clauses of the s_* helpers and with the wording of the
         fallback error message. *)
      fun swr (Var ((name,_),_)) = name
        | swr (Const(@{const_name "mpaq_Const"},_)$name$typ) 
            = "Const (" ^ s_string name ^ ", " ^ s_typ typ ^ ")"
        | swr (Const(@{const_name "mpaq_Free"},_)$name$typ)
            = "Free (" ^ s_string name ^ ", " ^ s_typ typ ^ ")"
        | swr (Const(@{const_name "mpaq_Var"},_)$name$typ)
            = "Var (" ^ s_indexname name ^ ", " ^ s_typ typ ^ ")"
        | swr (Const(@{const_name "mpaq_Bound"},_)$idx)
            = "Bound " ^ s_index idx
        | swr (Const(@{const_name "mpaq_Abs"},_)$name$T$t)
            = "Abs (" ^ s_string name ^ ", " ^ s_typ T ^ ", " ^ swr t ^ ")"
        | swr (Const(@{const_name "mpaq_App"},_)$t1$t2)
            = "(" ^ swr t1 ^ ") $ (" ^ swr t2 ^ ")"
        | swr t = raise TERM ("Expected var or term structure",[t])

      (* Printer for ordinary pattern terms.  Clause order matters: the
         mpaq_STRUCT and mpaq_AS_PAT applications must be tried before the
         generic application clause. *)
      fun vwr (Const (name,T)) 
            = "Const (" ^ ML_Syntax.print_string name ^ ", " ^ twr T ^ ")"
        | vwr (Var ((name,_),_)) = name
          (* Not reached from write: vcheck rejects free variables beforehand. *)
        | vwr (Free (name,T)) = "Free (" ^ name ^ ", " ^ twr T ^ ")"
        | vwr (Bound i) = "Bound " ^ ML_Syntax.print_int i
        | vwr (Abs ("_",T,t)) = "Abs (_," ^ twr T ^ "," ^ vwr t ^ ")"
          (* A named abstraction binds two ML identifiers: x for the bound
             variable name and x_T for its type. *)
        | vwr (Abs (x,T,t)) = "Abs (" ^ x ^ "," 
                ^ x ^ "_T as " ^ twr T ^ "," 
                ^ vwr t ^ ")"
        | vwr (Const(@{const_name mpaq_STRUCT},_)$t) = (
            swr t
          )
        | vwr (trm as Const(@{const_name "mpaq_AS_PAT"},_)$t$s) = ( 
            case t of
              (Var ((name,_),_)) => name ^ " as (" ^ vwr s ^ ")"
            | _ => raise TERM ("_AS_ must have identifier on LHS",[trm])
          )
        | vwr (t1$t2) = "(" ^ vwr t1 ^ ") $ (" ^ vwr t2 ^ ")"

      (* Validate and normalise first; the collected name lists are not needed. *)
      val (t,_) = vcheck t ([],[])
      val s = vwr t
    in
      "(" ^ s ^ ")"
    end
    (* Semantic action of the antiquotation: read the raw pattern and print
       it as ML pattern text.

       The pattern is read in a context extended by a term check (stage 90,
       named "Prepare dummies") that runs prepare_dummies'.  Any ERROR raised
       while reading or printing is re-raised with the source position of the
       pattern appended to its message. *)
    fun process (with_types, (ctxt, (raw_t, pos))) = (
      let
        val dummy_phase =
          Syntax_Phases.term_check 90 "Prepare dummies" (fn _ => prepare_dummies')
        val pattern_ctxt = Context.proof_map dummy_phase ctxt
        val pattern = Proof_Context.read_term_pattern pattern_ctxt raw_t
      in
        write (with_types, pattern)
      end
      handle ERROR msg => error (msg ^ Position.here pos)
    )
  in
    (* The only value exported from the local block: the argument parser
       read, followed by process, which yields the ML pattern text. *)
    val mpat_antiquot = read >> process
  end
›
(* Register mpat_antiquot as the inline ML antiquotation named "mpat", so
   that the pattern text it produces is inserted into the ML source at the
   place where the antiquotation is used. *)
setup ‹
  ML_Antiquotation.inline @{binding "mpat"} mpat_antiquot
›
subsection ‹Examples›
ML_val ‹
  (* Without the typs mode, types are not matched.  The schematic variable a
     becomes an ML variable bound to the first component of the pair; the
     dummy becomes an ML wildcard. *)
  fun dest_pair_singleton @{mpat "{(?a,_)}"} = (a)
    | dest_pair_singleton t = raise TERM ("dest_pair_singleton",[t])
  (* With typs, the types of the pattern are part of the generated ML
     pattern, so the type constraints on a and b take part in the match. *)
  fun dest_nat_pair_singleton @{mpat (typs) "{(?a::nat,?b::nat)}"} = (a,b)
    | dest_nat_pair_singleton t = raise TERM ("dest_nat_pair_singleton",[t])
  (* Schematic type variables whose names start with 'v_ are bound as ML
     variables with that prefix removed, here Ta and Tb. *)
  fun dest_pair_singleton_T @{mpat (typs) "{(?a::_ ⇒ ?'v_Ta,?b::?'v_Tb)}"} = 
    ((a,Ta),(b,Tb))
    | dest_pair_singleton_T t = raise TERM ("dest_pair_singleton_T",[t])
  (* A named abstraction binds two ML variables: the name itself for the
     bound variable name, and the name suffixed with _T for its type. *)
  fun dest_pair_lambda @{mpat "{(λa _ _. ?Ta, λb. ?Tb)}"} = (a,a_T,b,b_T,Ta,Tb)
    | dest_pair_lambda t = raise TERM ("dest_pair_lambda",[t])
  (* The left-hand side of either AS operator is bound to the whole subterm.
     With the structure variant, b must have the ML shape Bound i and i is
     bound to the index; with the pattern variant, c must match the
     singleton list pattern and n is bound to its element. *)
  fun foo @{mpat "[?a,?b AS⇩s mpaq_Bound ?i,?c AS⇩p [?n]]"} = 
    (a,b,i,c,n)
  | foo t = raise TERM ("foo",[t])
›
(* Hide the names of the auxiliary types declared at the top of this theory.
   NOTE(review): with the (open) option only the unqualified names should be
   hidden, leaving the qualified names accessible - confirm against the Isar
   reference manual. *)
hide_type (open) term_struct typ_struct sort
end