(* stateful_ops.ML (20.10 KiB)
   Author: Lars Hupel *)
(* Operations on one globally shared local theory ("the loaded theory"). *)
signature STATEFUL_OPS = sig
(* load (root_name, thy_name, scripts): begins theory thy_name on top of the root theory and the
   given (name, text) scripts; returns the names of the scripts that failed to load *)
val load: string * string * (string * string) list -> string list
(* registers the oracle in the loaded theory *)
val oracle: unit -> unit
(* fixes a variant of the given name in the loaded context and returns the chosen name *)
val fresh: string -> string
(* tries to prove the disjunction of the given terms, optionally with the given method source;
   returns the printed theorem on success *)
val prove: term list * string option -> string option
(* declares the given names in the loaded context *)
val declare: string list -> unit
(* ((scrutinee, bound variables), (expected result type, (pattern, right-hand side) clauses)):
   builds a checked case expression *)
val cases: (term * (string * typ) list) * (typ * (term * term) list) -> term
(* (function name, right-hand side source) pairs; returns the names whose equivalence was not proved *)
val equivalences: (string * string) list -> string list
(* like equivalences, but the theorems are obtained from the oracle *)
val assume_equivalences: (string * string) list -> unit
(* (name, proposition) pairs; returns the names of the lemmas that were not proved *)
val lemmas: (string * term) list -> string list
(* like lemmas, but the theorems are obtained from the oracle *)
val assume_lemmas: (string * term) list -> unit
(* (name, arguments, (right-hand side, expected type)) triples; NONE on success, otherwise a printed error *)
val functions: (string * (string * typ) list * (term * typ)) list -> string option
(* (name, type parameters, (constructor name, fields) list) triples; NONE on success, otherwise an error message *)
val datatypes: (string * string list * (string * (string * typ) list) list) list -> string option
(* (datatype name, (internal name, constructor name) pairs): Left with an error message, or Right
   with (constructor, discriminator, selectors) per internal name *)
val lookup_datatype: string * (string * string) list -> (string, (string * (term * term * term list)) list) Codec.either
(* prints a term in the loaded context *)
val pretty: term -> string
(* (kind, text) entries for everything recorded so far, oldest first *)
val report: unit -> (string * string) list
(* (theory name, theory source text) pairs reconstructed from the recorded reports *)
val dump: unit -> (string * string) list
(* internal *)
(* replaces the global state by a local theory initialised from the given theory, or clears it on NONE *)
val reset: theory option -> unit
(* returns the current global state *)
val peek: unit -> local_theory option
(* table recorded by functions: name -> (optional induction rule, simp rules) *)
val get_functions: theory -> (thm option * thm list) Symtab.table
end
structure Stateful_Ops: STATEFUL_OPS = struct
(* global timeout for auxiliary operations: delegates to try_timeout with a limit of 5 *)
fun try_timeout' action = try_timeout 5 action
(* Runs f on x under the global timeout. A time-out is turned into Timeout msg;
   any other exception is re-raised; otherwise the result is returned. *)
fun err_timeout msg f x =
  let
    val outcome = try_timeout' f x
  in
    (case outcome of
      Exn.Exn TimeLimit.TimeOut => raise Timeout msg
    | Exn.Exn other => reraise other
    | Exn.Res value => value)
  end
(* Log entries kept in theory data (see Reports); rendered by report and dump. *)
datatype report =
(* a defined function: heads are (name, printed type) pairs, raws the printed exported theorems,
   specs the printed simp rules, inducts (name, printed induction rule) pairs *)
Function of {heads: (string * string) list, raws: string list, specs: string list, inducts: (string * string) list} |
(* printed datatype declaration *)
Datatype of string |
(* the loaded theory: name of the imported root theory and the (name, text) scripts *)
Theory of {import: string, scripts: (string * string) list} |
(* printed equivalence theorem *)
Equivalence of string |
(* a lemma: its name and printed proposition *)
Lemma of {name: string, prop: string} |
(* the oracle has been set up *)
Oracle |
(* closing marker; in the visible code it is only added by dump, never stored *)
End
(* Theory data holding the list of recorded reports. Merging is not supported. *)
structure Reports = Theory_Data
(
  type T = report list
  val empty: T = []
  val extend = I
  fun merge _ = impossible "Reports.merge" @{here}
)
(* Theory data holding, per key, an optional theorem and a list of theorems.
   Merging is not supported. *)
structure Functions = Theory_Data
(
  type T = (thm option * thm list) Symtab.table
  val empty: T = Symtab.empty
  val extend = I
  fun merge _ = impossible "Functions.merge" @{here}
)
(* Theory data holding the oracle function, if one has been set. Merging is not supported. *)
structure Oracle = Theory_Data
(
  type T = (cterm -> thm) option
  val empty: T = NONE
  val extend = I
  fun merge _ = impossible "Oracle.merge" @{here}
)
(* Prepends a report to the theory's report list. *)
fun add_report entry = Reports.map (cons entry)
(* Inserts one new (key, value) entry into the Functions table via Symtab.update_new. *)
fun add_fun entry = Functions.map (Symtab.update_new entry)
(* Inserts a list of entries into the Functions table, one after the other. *)
fun add_funs entries = fold add_fun entries
(* Adds the oracle leon_oracle to the theory, stores the resulting oracle function
   in the Oracle theory data and records an Oracle report. *)
fun set_oracle thy =
  let
    val ((_, oracle), thy') = Thm.add_oracle (@{binding leon_oracle}, I) thy
  in
    thy'
    |> Oracle.map (K (SOME oracle))
    |> add_report Oracle
  end
(* Returns the oracle function stored in the theory of the given context;
   fails with "expected oracle" if none has been set. *)
fun get_oracle ctxt =
  the_error "expected oracle" @{here} (Oracle.get (Proof_Context.theory_of ctxt))
(* Returns the Functions table of a theory. *)
fun get_functions thy = Functions.get thy
(* The global state: the currently loaded local theory, if any. *)
type state = local_theory option

val empty_state: state = NONE

val state = Synchronized.var "libisabelle.leon_state" empty_state
(* Sets Goal.parallel_proofs to 0. *)
fun unset_parallel_proofs () =
  let val level = 0 in Goal.parallel_proofs := level end
(* Replaces the global state: SOME thy becomes a local theory initialised from thy,
   NONE clears the state. Parallel proofs are switched off first. *)
fun reset thy =
  let
    val _ = unset_parallel_proofs ()
    val replacement = Option.map Named_Target.theory_init thy
  in
    Synchronized.change state (K replacement)
  end
(* Returns the current global state without modifying it. *)
fun peek () =
  let val current = Synchronized.value state in current end
local
(* Reports the invalid state "no theory loaded" at the given position. *)
fun not_loaded position =
  invalid_state "no theory loaded" position
(* Reports the invalid state "theory <name> already loaded" for the theory underlying lthy.
   The position is supplied by the caller as a further argument.
   Fix: the message previously lacked the space between the theory name and "already". *)
fun already_loaded lthy =
  let
    val thy_name = Context.theory_name (Proof_Context.theory_of lthy)
  in
    invalid_state ("theory " ^ thy_name ^ " already loaded")
  end
in
(* Applies f to the loaded local theory without changing the state;
   fails via not_loaded if nothing is loaded. *)
fun access_loaded f =
  let
    val current = Synchronized.value state
  in
    (case current of
      NONE => not_loaded @{here}
    | SOME thy => f thy)
  end
(* Applies f to the loaded local theory; f returns a result and an updated local theory.
   The updated theory replaces the state and the result is returned.
   Fails via not_loaded if nothing is loaded. *)
fun update_loaded f =
  Synchronized.change_result state
    (fn SOME thy => f thy ||> SOME
      | NONE => not_loaded @{here})
(* load (root_name, thy_name, scripts): loads every (name, text) script as a theory of its own on
   top of the root theory, then begins theory thy_name importing the root and all script theories.
   Returns the names of the scripts that failed. The global state is only set when no script
   failed; otherwise it stays empty. Fails with an invalid-state error if a theory is already
   loaded or if the Execution.is_running_exec check below does not hold. *)
fun load (root_name, thy_name, scripts) =
let
val _ = unset_parallel_proofs ()
val _ =
if Execution.is_running_exec Document_ID.none then ()
else invalid_state "Must not be running in interactive mode" @{here}
val root = Thy_Info.get_theory root_name
(* runs one script in a fresh theory based on root; the outcome (result or exception) is captured *)
fun load_script (name, txt) =
let
val thy = Theory.begin_theory (name, Position.none) [root]
in (name, Exn.interruptible_capture (script_thy Position.none txt) thy) end
(* successful scripts contribute their theory, failed ones their name *)
fun classify (_, Exn.Res thy) = ([thy], [])
| classify (name, Exn.Exn _) = ([], [name])
(* state transition, run inside Synchronized.change_result *)
fun upd (SOME lthy) = already_loaded lthy @{here}
| upd NONE =
let
val (good, bad) = map classify (Par_List.map load_script scripts)
|> split_list ||> flat |>> flat
(* delayed, so that the main theory is only begun when all scripts succeeded *)
fun thy () =
Theory.begin_theory (thy_name, Position.none) (root :: good)
|> add_report (Theory {import = root_name, scripts = scripts})
val res =
if null bad then
SOME (Named_Target.theory_init (thy ()))
else
NONE
in (bad, res) end
in Synchronized.change_result state upd end
end
(* Sets up the oracle in the background theory of the loaded local theory. *)
fun oracle () =
  update_loaded (fn lthy => ((), Local_Theory.background_theory set_oracle lthy))
(* Fixes a variant of the name n in the loaded context (Variable.variant_fixes)
   and returns the name that was chosen. *)
fun fresh n =
  update_loaded (fn lthy =>
    let
      val (variants, lthy') = Variable.variant_fixes [n] lthy
    in
      (hd variants, lthy')
    end)
(* Tries to prove the disjunction of the terms ts in the loaded context, without changing the
   state. With method = SOME src the given method source is used on the head goal, otherwise
   prove_tac. Returns the printed theorem, or NONE if the proof attempt raised an exception. *)
fun prove (ts, method) =
  access_loaded (fn loaded =>
    let
      val (registered, ctxt) = register_terms ts loaded
      val goal = Balanced_Tree.make HOLogic.mk_disj registered
      fun tactic_for goal_ctxt =
        (case method of
          SOME src => HEADGOAL (method_tac @{here} src goal_ctxt)
        | NONE => prove_tac goal_ctxt)
      (* Assumption: all proofs are sequential *)
      val attempt =
        try (Goal.prove ctxt [] [] goal) (fn {context, ...} => tactic_for context)
    in
      Option.map (print_thm ctxt) attempt
    end)
(* gen_lemmas prove props: for every (name, prop) pair, obtains a theorem via prove (which
   returns an option). Successful theorems are noted as simp rules under the qualified binding
   "lemma.<name>" and recorded as Lemma reports. Returns the names for which prove gave NONE. *)
fun gen_lemmas prove props =
update_loaded (fn lthy =>
let
fun process (name, prop) lthy =
let
(* lthy' is only used for the proof attempt itself *)
val ([prop'], lthy') = register_terms [prop] lthy
val binding = Binding.qualified true "lemma" (Binding.make (name, @{here}))
in
case prove lthy' prop' of
(* failure: report the name and continue with the unchanged context *)
NONE => ([name], lthy)
| SOME thm =>
(* the theorem is noted in lthy, not lthy'; marked as deliberate by the original author *)
lthy (* sic *)
|> Local_Theory.note ((binding, @{attributes [simp]}), [thm]) |> snd
|> Local_Theory.background_theory (add_report (Lemma {name = name, prop = print_thm lthy thm}))
|> pair []
end
in
fold_map process props lthy
|>> flat
end)
(* Proves lemmas with prove_tac under the global timeout;
   returns the names of the lemmas that could not be proved. *)
val lemmas =
  let
    fun attempt ctxt goal =
      Exn.get_res
        (try_timeout' (Goal.prove ctxt [] [] goal) (fn {context, ...} => prove_tac context))
  in
    gen_lemmas attempt
  end
(* Like lemmas, but every proposition is turned into a theorem by the oracle;
   the list of failed names is discarded. *)
fun assume_lemmas props =
  let
    fun by_oracle ctxt goal = SOME (get_oracle ctxt (Thm.cterm_of ctxt goal))
    val _ = gen_lemmas by_oracle props
  in
    ()
  end
(* Declares the given names, as free variables of dummy type, in the loaded context
   (Variable.declare_names). *)
fun declare names =
  let
    val frees = map (fn name => Free (name, dummyT)) names
  in
    update_loaded (fn lthy => ((), fold Variable.declare_names frees lthy))
  end
(* cases ((raw_scr, bounds), (expected_typ, clauses)): builds a case expression over the
   scrutinee raw_scr from (pattern, right-hand side) clauses in the loaded context, without
   changing the state. The scrutinee and the right-hand sides are passed through unbound,
   which destructs one abstraction per entry of bounds. *)
fun cases ((raw_scr, bounds), (expected_typ, clauses)) =
access_loaded (fn ctxt =>
let
fun dest_abs (name, typ) t = Term.dest_abs (name, typ, t)
(* destructs one abstraction per bound and keeps only the remaining body *)
val unbound = fold_map dest_abs bounds #> snd
val scr = Syntax.check_term ctxt (unbound raw_scr)
val scr_typ = fastype_of scr
val names = Variable.names_of ctxt
fun process_clause (raw_pat, raw_rhs) =
let
(* the pattern is checked against the type of the scrutinee *)
val pat = Syntax.check_term ctxt (Type.constraint scr_typ raw_pat)
(* frees of the pattern that are undeclared in ctxt are fixed, possibly under new names,
   and substituted in both the pattern and the right-hand side *)
val frees = all_undeclared_frees ctxt pat
val (free_names', ctxt') = Variable.add_fixes (map fst frees) ctxt
val frees' = map Free (free_names' ~~ map snd frees)
val frees = map Free frees
val (pat', raw_rhs') = apply2 (subst_free (frees ~~ frees')) (pat, unbound raw_rhs)
(* the right-hand side is checked against the expected type, with the pattern declared *)
val rhs' =
Type.constraint expected_typ raw_rhs'
|> Syntax.check_term (Variable.declare_term pat' ctxt')
in
(pat', rhs')
end
val term =
map process_clause clauses
|> Case_Translation.make_case ctxt Case_Translation.Quiet names scr
|> Syntax.check_term ctxt
(* looks up the type of the free variable called name inside a term *)
fun find_type (t $ u) name = merge_options (find_type t name, find_type u name)
| find_type (Abs (_, _, t)) name = find_type t name
| find_type (Free (name', typ)) name = if name = name' then SOME typ else NONE
| find_type _ _ = NONE
(* bounds given with dummy type take the type found in the checked term, if any *)
fun get_type (name, typ) =
if typ = dummyT then
the_default dummyT (find_type term name)
else
typ
val new_bounds = map fst bounds ~~ map get_type bounds
in
(* abstracts over the bounds and strips the abstractions again; presumably this turns the
   frees back into loose bound variables for the caller -- TODO confirm *)
fold (Term.lambda o Free) new_bounds term |> strip_abs_body
end)
(* gen_equivalences prove eqs: for every (lhs, rhs) pair, where lhs names a function recorded in
   the Functions table and rhs is term source text, states the goal "lhs = rhs" and obtains a
   theorem via prove lthy goal rule (rule is the optional theorem stored for the function).
   On success the stored theorems of the function are removed from the simpset and the new
   theorem is noted as a simp rule under "equiv.<lhs>"; an Equivalence report is recorded.
   Returns the names for which prove gave NONE. *)
fun gen_equivalences prove eqs =
update_loaded (fn lthy =>
let
fun prepare (lhs, rhs) =
let
val lhs' = qualify lthy lhs
(* the table is keyed by the name as given, not by the qualified name *)
val (rule, specs) =
Symtab.lookup (Functions.get (Proof_Context.theory_of lthy)) lhs
|> the_error "equivalences" @{here}
val goal = mk_untyped_eq (Const (lhs', dummyT), Syntax.parse_term lthy rhs)
|> HOLogic.mk_Trueprop
|> Syntax.check_term lthy
in
(lhs, specs, goal, rule)
end
(* all goals are prepared in the initial context, before any theorem is noted *)
val prepared = map prepare eqs
fun process (name, specs, goal, rule) lthy =
let
val binding = Binding.qualified true "equiv" (Binding.make (name, @{here}))
in
case prove lthy goal rule of
NONE => (([name], []), lthy)
| SOME thm =>
lthy
|> Local_Theory.note ((Binding.empty, @{attributes [simp del]}), specs) |> snd
|> Local_Theory.note ((binding, @{attributes [simp]}), [thm]) |> snd
|> pair ([], [thm])
end
(* bad: names that failed; good: theorems that were obtained *)
val ((bad, good), lthy') =
fold_map process prepared lthy
|>> split_list |>> apfst flat |>> apsnd flat
val lthy'' =
Local_Theory.background_theory (fold (add_report o Equivalence o print_thm lthy') good) lthy'
in
(bad, lthy'')
end)
(* Like equivalences, but every goal is turned into a theorem by the oracle;
   the stored rule is ignored and the list of failed names is discarded. *)
fun assume_equivalences eqs =
  let
    fun by_oracle lthy goal _ = SOME (get_oracle lthy (Thm.cterm_of lthy goal))
    val _ = gen_equivalences by_oracle eqs
  in
    ()
  end
(* Proves equivalences with equiv_tac (given the stored rule) under the global timeout;
   returns the names whose equivalence could not be proved. *)
val equivalences =
  let
    fun by_tactic lthy goal rule =
      Exn.get_res
        (try_timeout' (Goal.prove lthy [] [] goal) (fn {context, ...} => equiv_tac rule context))
  in
    gen_equivalences by_tactic
  end
(* functions raw_funs: defines the given functions in the loaded theory. Every entry is
   (name, arguments, (right-hand side, expected type)). Returns NONE on success. If the
   construction raises an exception, Impossible is re-raised; any other exception is printed and
   returned as SOME message, and the state is left as it was. *)
fun functions raw_funs =
update_loaded (fn lthy =>
let
(* builds the binding of the function and the unchecked equation "name args = rhs" *)
fun transform (name, raw_args, (raw_rhs, expected_typ)) =
let
val lhs = list_comb (Free (name, dummyT), map Free raw_args)
val rhs = Type.constraint expected_typ raw_rhs
in
((Binding.make (name, @{here}), NONE, NoSyn),
HOLogic.mk_Trueprop (Const (@{const_name HOL.eq}, dummyT) $ lhs $ rhs))
end
(* true if at least one function has no arguments *)
val no_args = exists (equal 0 o length o #2) raw_funs
(* checked equations paired with their heads; types are homogenized only for several functions *)
val specs =
map transform raw_funs
|> split_list
||> map (Syntax.check_term lthy)
||> length raw_funs > 1 ? homogenize_raw_types lthy
||> map (pair (Binding.empty, []))
|> op ~~
(* a single function is recursive if it occurs free in its own right-hand side;
   anything other than exactly one specification is treated as recursive *)
val recursive =
case specs of
[(_, (_, prop))] =>
let
val ((f, _), rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop prop) |>> strip_comb
in
exists (equal f o Free) (all_frees rhs)
end
| _ => true
(* function package: definition and termination proof, each under err_timeout *)
fun fun_construction lthy =
let
val (heads, props) = split_list specs
val construction = Function.add_function heads props Function_Fun.fun_config pat_completeness_auto_tac
val (_, lthy) = err_timeout "function construction" construction lthy
val (info, lthy) = err_timeout "termination proof" (Function.prove_termination NONE (termination_tac lthy)) lthy
in
#simps info
|> the_error "simps of function definition" @{here}
|> rpair lthy
end
(* plain definition; only reached with exactly one specification *)
fun def_construction lthy =
let val [(head, prop)] = specs in
Specification.definition (SOME head, prop) lthy
|>> snd |>> snd |>> single
end
fun construction lthy = lthy |>
(case (recursive, no_args) of
(true, true) => raise Unsupported ("Mutual definition without arguments: " ^ commas (map #1 raw_funs))
| (true, false) => fun_construction
| (false, _) => def_construction)
in
case Exn.interruptible_capture construction lthy of
Exn.Res (thms, lthy) =>
let
val names = map #1 raw_funs
(* printed type of the head of the left-hand side of a defining equation *)
val get_typ =
Thm.prop_of #> HOLogic.dest_Trueprop #> HOLogic.dest_eq #> fst
#> strip_comb #> fst
#> fastype_of
#> print_typ lthy
val typs = map get_typ thms
val lthy' = Local_Theory.restore lthy
val thms = Proof_Context.export lthy lthy' thms
(* per exported theorem: the derived list of simp rules *)
val simpss = thms
|> map (try_case_to_simps lthy' o restore_eq o full_simplify (mk_simp_ctxt lthy'))
|> map (homogenize_spec_types lthy')
val simps = maps if_to_cond_simps (flat simpss)
(* name paired with (optional induction schema, simp rules); the schema is computed under the
   global timeout and is NONE if that fails *)
val specs = names ~~ map (`(try_timeout' (mk_induction_schema lthy') #> Exn.get_res)) simpss
fun print_induct (_, (NONE, _)) = NONE
| print_induct (name, (SOME rule, _)) = SOME (name, print_thm lthy' rule)
val binding = Binding.qualified true "spec" (Binding.make (hd names, @{here}))
val report =
Function
{heads = names ~~ typs,
inducts = map_filter print_induct specs,
specs = map (print_thm lthy') simps,
raws = map (print_thm lthy') thms}
(* the raw theorems leave the simpset, the derived rules enter it under "spec.<first name>" *)
val lthy'' = lthy'
|> Local_Theory.note ((Binding.empty, @{attributes [simp del]}), thms) |> snd
|> Local_Theory.note ((binding, @{attributes [simp]}), simps) |> snd
|> Local_Theory.background_theory (add_funs specs #> add_report report)
in
(NONE, lthy'')
end
| Exn.Exn (exn as Impossible _) =>
reraise exn
| Exn.Exn exn =>
(* here lthy is the context from before the construction *)
(SOME (print_exn exn), lthy)
end)
(* datatypes raw_dts: defines the given datatypes in the loaded theory. Every entry is
   (name, type parameters, (constructor name, (field name, type) list) list). The datatypes are
   grouped with compute_sccs over their dependency edges and defined group by group. Returns
   NONE on success, or SOME message (state unchanged) if the datatypes of some group do not all
   have the same number of type parameters. *)
fun datatypes raw_dts =
update_loaded (fn lthy =>
let
(* pairs every declaration with its qualified name *)
val raw_dts = map (apfst (qualify lthy o #1) o dupl) raw_dts
(* edges from a datatype to every type constructor name occurring in its fields *)
fun mk_edges (qname, (_, _, constrs)) =
let
fun collect_types (Type (name, typs)) = name :: maps collect_types typs
| collect_types _ = []
val deps = constrs |> maps snd |> maps (collect_types o snd)
in map (pair qname) deps end
fun free tp = TFree ("'" ^ tp, @{sort type})
(* renames the type parameters of all datatypes of a group to those of the first one;
   only defined for non-empty groups *)
fun homogenize ((dt as (_, tps, _)) :: dts) =
let
fun subst tps' = typ_subst_atomic (map free tps' ~~ map free tps)
fun replace (name, tps', constrs) = (name, tps, map (apsnd (map (apsnd (subst tps')))) constrs)
in
dt :: map replace dts
end
val sccs = compute_sccs raw_dts (maps mk_edges raw_dts)
(* a group is fine if all its datatypes have the same number of type parameters *)
fun check_arity_dts dts = length (distinct op = (map (length o #2) dts)) = 1
val check_arity = not (exists (equal false) (map check_arity_dts sccs))
fun transform_tp raw_tp =
(SOME (Binding.name raw_tp), (free raw_tp, @{sort type}))
(* naming scheme: selectors get the prefix "s", discriminators "d", constructors "c" *)
fun transform_field (name, typ) =
(Binding.name ("s" ^ name), typ)
fun transform_constr (name, raw_fields) =
(((Binding.name ("d" ^ name), Binding.name ("c" ^ name)), map transform_field raw_fields), NoSyn)
fun transform (name, raw_tps, raw_constrs) =
let
val dt = ((map transform_tp raw_tps, Binding.make (name, @{here})), NoSyn)
val constrs = map transform_constr raw_constrs
in
(((dt, constrs), (@{binding map}, @{binding rel})), [])
end
(* textual datatype declaration of one group, used for the Datatype report *)
fun print lthy dts =
let
fun print_tp tp = "'" ^ tp
fun print_tps [] = ""
| print_tps tps = enclose "(" ")" (commas (map print_tp tps))
fun print_field (name, typ) =
"(s" ^ name ^ ": " ^ quote (print_typ lthy typ) ^ ")"
fun print_constr (name, fields) =
"d" ^ name ^ ": c" ^ name ^ " " ^ space_implode " " (map print_field fields)
fun print_one (name, tps, constrs) =
print_tps tps ^ " " ^ name ^ " = " ^ space_implode " | " (map print_constr constrs)
in
"datatype " ^ space_implode " and " (map print_one dts)
end
(* defines one group and registers the split rules and case theorems of its datatypes *)
fun setup_dts scc lthy =
let
val input = (Ctr_Sugar.default_ctr_options, map transform scc)
val lthy =
BNF_FP_Def_Sugar.co_datatypes BNF_Util.Least_FP BNF_LFP.construct_lfp input lthy
|> Local_Theory.restore
fun get_sugar (name, _, _) =
qualify lthy name |> Ctr_Sugar.ctr_sugar_of lthy
|> the_error ("ctr_sugar of " ^ name) @{here}
val sugars = map get_sugar scc
fun get_splits {split, split_asm, ...} = [split, split_asm]
fun get_cases {case_thms, ...} = case_thms
in
lthy
|> Local_Theory.note ((Binding.empty, @{attributes [split]}), maps get_splits sugars) |> snd
|> Local_Theory.note ((Binding.empty, @{attributes [leon_unfold]}), maps get_cases sugars) |> snd
end
in
if check_arity then
let
val sccs = map homogenize sccs
val lthy' = fold setup_dts sccs lthy
val lthy'' = Local_Theory.background_theory (fold (add_report o Datatype o print lthy') sccs) lthy'
in (NONE, lthy'') end
else
(SOME "Unsupported feature: mismatched datatype arity", lthy)
end)
(* lookup_datatype (name, raw_constructors): looks up an existing datatype by name.
   raw_constructors are (internal name, constructor name) pairs. If the datatype is unknown, or
   the sorted constructor names differ from the sorted declared ones, a Left message is returned
   and the state is unchanged. Otherwise the split rules and case theorems are registered and
   Right is returned with (internal name, (constructor, discriminator, selectors)) entries. *)
fun lookup_datatype (name, raw_constructors) =
  let
    fun describe (internal_name, ctr_name) (disc, sels) =
      (internal_name, (Const (ctr_name, dummyT), dummify_types disc, map dummify_types sels))
    fun lookup lthy =
      (case Ctr_Sugar.ctr_sugar_of lthy name of
        SOME {ctrs, discs, selss, split, split_asm, case_thms, ...} =>
          let
            val requested = sort_by snd raw_constructors
            val ctr_names = map (fst o dest_Const) ctrs
            val (names, terms) = split_list (sort_by fst (ctr_names ~~ (discs ~~ selss)))
            val (_, with_splits) =
              Local_Theory.note ((Binding.empty, @{attributes [split]}), [split, split_asm]) lthy
            val (_, lthy') =
              Local_Theory.note ((Binding.empty, @{attributes [leon_unfold]}), case_thms) with_splits
          in
            if map snd requested = names then
              (Codec.Right (map2 describe requested terms), lthy')
            else
              (Codec.Left ("constructor name mismatch: declared constructors are " ^ commas names), lthy)
          end
      | NONE =>
          (Codec.Left ("unknown datatype " ^ name), lthy))
  in
    update_loaded lookup
  end
(* Prints the term t in the loaded context. *)
fun pretty t =
  let
    fun render lthy = print_term lthy t
  in
    access_loaded render
  end
(* Returns the recorded reports of the loaded theory as (kind, text) pairs, oldest first.
   A Theory report yields one "Script" entry per script name; a Function report yields one
   entry per raw theorem, per spec and per induction rule. *)
fun report () =
  access_loaded (fn lthy =>
    let
      fun print_report (Theory {scripts, ...}) = map (pair "Script" o fst) scripts
        | print_report (Equivalence prop) = [("Equivalence", prop)]
        | print_report (Datatype spec) = [("Datatype", spec)]
        | print_report (Lemma {name, prop}) = [("Lemma", name ^ ": " ^ prop)]
        | print_report (Function {raws, inducts, specs, ...}) =
            map (pair "Function (raw)") raws @
            map (pair "Function (spec)") specs @
            map (fn (name, induct) => ("Function (induct)", name ^ ": " ^ induct)) inducts
        | print_report Oracle = [("Oracle", "enabled")]
        (* End is only a closing marker used by dump; it contributes nothing here.
           The clause makes the match exhaustive. *)
        | print_report End = []
    in
      Proof_Context.theory_of lthy
      |> Reports.get
      (* reports are stored newest first *)
      |> rev
      |> maps print_report
    end)
(* Reconstructs theory source text from the recorded reports of the loaded theory.
   Returns (theory name, text) pairs: first the loaded theory itself, then one theory per
   script. Lemmas and equivalences are emitted as simp lemmas closed with sorry, functions as
   "fun" declarations, and the text is closed with "end". *)
fun dump () =
  let
    fun thy_header thy_name imports =
      "theory " ^ thy_name ^ "\n" ^
      "imports " ^ space_implode " " imports ^ "\n" ^
      "begin\n"
    fun script_text script_name import text =
      thy_header script_name [import] ^ text ^ "\nend\n"
    fun sorry_lemma prop = "lemma [simp]: " ^ quote prop ^ "\nsorry"
    fun render lthy =
      let
        val thy = Proof_Context.theory_of lthy
        val name = Context.theory_name thy
        (* each report yields (script theories, text chunk for the main theory) *)
        fun print_report End = ([], "\nend\n")
          | print_report Oracle = ([], "")
          | print_report (Datatype spec) = ([], spec)
          | print_report (Lemma {prop, ...}) = ([], sorry_lemma prop)
          | print_report (Equivalence prop) = ([], sorry_lemma prop)
          | print_report (Function {heads, specs, ...}) =
              let
                val typed_heads = map (fn (f, typ) => quote f ^ " :: " ^ quote typ) heads
                val head = "fun " ^ space_implode " and " typed_heads ^ " where\n"
                val body = space_implode "|\n" (map quote specs)
              in
                ([], head ^ body)
              end
          | print_report (Theory {scripts, import}) =
              (map (fn (script_name, text) => (script_name, script_text script_name import text)) scripts,
               thy_header name (import :: map fst scripts))
        (* reports are stored newest first, so End is prepended before reversing *)
        val chronological = rev (End :: Reports.get thy)
        val (script_theories, chunks) = split_list (map print_report chronological)
      in
        (name, space_implode "\n\n" chunks) :: flat script_theories
      end
  in
    access_loaded render
  end
end