From Ltac2 Require Import Ltac2.
From Hammer Require Import Tactics.
From Equations Require Import Equations.
Require Export LP.lattice.

Set Default Proof Mode "Classic".

(* Syntax of lattice expressions whose atoms have type [A]: an atom [Var a],
   or the [Meet] / [Join] of two sub-expressions. *)
Inductive lexp (A : Set) : Set :=
| Var : A -> lexp A
| Meet : lexp A -> lexp A -> lexp A
| Join : lexp A -> lexp A -> lexp A.
(* Make the atom type implicit in every constructor so that it is inferred
   from the arguments at use sites (terms and patterns alike). *)
Arguments Var {A}.
Arguments Meet {A}.
Arguments Join {A}.

(* Interpretation of a lattice expression in the lattice [A]: an atom denotes
   itself, and [Meet] / [Join] nodes are mapped to the lattice operations
   [meet] / [join] applied to the interpretations of the sub-expressions. *)
Fixpoint denoteLexp {A : Set} {_:Lattice A} (e : lexp A) : A :=
  match e with
  | Var x => x
  | Meet l r => meet (denoteLexp l) (denoteLexp r)
  | Join l r => join (denoteLexp l) (denoteLexp r)
  end.

(* Number of [Meet] / [Join] nodes in an expression; atoms contribute 0.
   Used below as the well-founded measure (summed over both arguments) for
   [splitLeq] and [splitLeqForward]. *)
Fixpoint lexp_size {A : Set} (e : lexp A) : nat :=
  match e with
  | Var _ => 0
  | Meet l r | Join l r => 1 + lexp_size l + lexp_size r
  end.

From Equations Require Import Equations.

(* Goal-side decomposition of [leq_lat (denoteLexp e1) (denoteLexp e2)] into a
   proposition that exposes the structure of both expressions. Clauses are
   matched top to bottom (several can apply to the same pair), giving:
   - two atoms: compare them directly with [leq_lat];
   - [Join] on the left: both joinands must be below [e2];
   - [Meet] on the right: [e1] must be below both meetands;
   - [Join] on the right: [e1] below either joinand, or, because that is only
     sufficient and not necessary in a general lattice, the original
     inequality on the denotations as a fallback disjunct;
   - [Meet] on the left: dually, either meetand below [e2], or the original
     inequality as a fallback.
   Recursion is well-founded on [lexp_size e1 + lexp_size e2]; the decrease
   obligations are discharged with [sfirstorder]. The fallback disjuncts are
   what make [splitLeq_complete] provable. *)
#[tactic="sfirstorder"] Equations splitLeq {A : Set} `{Lattice A} (e1 : lexp A) (e2 : lexp A) : Prop
  by wf (lexp_size e1 + lexp_size e2) lt :=
  splitLeq (Var a1) (Var a2) => leq_lat a1 a2;
  splitLeq (Join e11 e12) e2 => splitLeq e11 e2 /\ splitLeq e12 e2;
  splitLeq e1 (Meet e21 e22) => splitLeq e1 e21 /\ splitLeq e1 e22;
  splitLeq e1 (Join e21 e22) => splitLeq e1 e21 \/ splitLeq e1 e22 \/ (leq_lat (denoteLexp e1) (denoteLexp (Join e21 e22))) ;
                               splitLeq (Meet e11 e12) e2 => splitLeq e11 e2 \/ splitLeq e12 e2 \/ (leq_lat (denoteLexp (Meet e11 e12)) (denoteLexp e2)).


(* Hypothesis-side decomposition of [leq_lat (denoteLexp e1) (denoteLexp e2)].
   Only the two splits that are equivalences are performed (compare
   [leq_join_iff] / [leq_meet_iff]): a [Join] on the left and a [Meet] on the
   right, each becoming a conjunction. Two atoms compare directly; every other
   pair is kept as the inequality between the denotations. Clause order
   matters, as in [splitLeq]. Recursion is well-founded on
   [lexp_size e1 + lexp_size e2], with obligations solved by [sfirstorder]. *)
#[tactic="sfirstorder"] Equations splitLeqForward {A : Set} `{Lattice A} (e1 : lexp A) (e2 : lexp A) : Prop
  by wf (lexp_size e1 + lexp_size e2) lt :=
  splitLeqForward (Var a1) (Var a2) => leq_lat a1 a2;
  splitLeqForward (Join e11 e12) e2 => splitLeqForward e11 e2 /\ splitLeqForward e12 e2;
  splitLeqForward e1 (Meet e21 e22) => splitLeqForward e1 e21 /\ splitLeqForward e1 e22;
  splitLeqForward e1 e2 => leq_lat (denoteLexp e1) (denoteLexp e2).

From Coq Require Import ssreflect.

(* Join-based formulation of the lattice order: [e1] is below [e2] when
   joining [e1] onto [e2] leaves [e2] unchanged. Its equivalence with
   [leq_lat] is [leq_lat_leq_lat'_iff]. *)
Definition leq_lat' {A : Set} {_ : Lattice A} (e1 e2 : A) : Prop :=
  join e1 e2 = e2.

(* The order [leq_lat] (written [≤]) coincides with the join-based order
   [leq_lat']. Proved by automation after unfolding both orders, using
   commutativity and absorption of [meet] / [join], with inversion on
   [Lattice] enabled. *)
Lemma leq_lat_leq_lat'_iff {A : Set} {_:Lattice A} :
  forall e1 e2, e1 ≤ e2 <-> leq_lat' e1 e2.
Proof.
  strivial
    use: @join_commutative,
      @meet_absorptive,
      @meet_commutative,
      @join_absorptive unfold: leq_lat', leq_lat inv: Lattice.
Qed.

(* [e1] is below [meet e2 e3] iff it is below both [e2] and [e3]; in other
   words, meet is the greatest lower bound. Registered below as a rewrite
   rule for the lattice automation. *)
Lemma leq_meet_iff {A : Set} {_:Lattice A} (e1 e2 e3 : A) :
  leq_lat e1 (meet e2 e3) <-> leq_lat e1 e2 /\ leq_lat e1 e3.
Proof.
  qauto depth: 4 l: on use: @meet_commutative, @join_absorptive
                       unfold: leq_lat, meet, join inv: Lattice.
Qed.

(* [join e2 e3] is below [e1] iff both [e2] and [e3] are; in other words,
   join is the least upper bound. Registered below as a rewrite rule for the
   lattice automation. *)
Lemma leq_join_iff {A : Set} {_:Lattice A} (e2 e3 e1 : A) :
  leq_lat (join e2 e3) e1 <-> leq_lat e2 e1 /\ leq_lat e3 e1.
Proof.
  (* Switch to the join-based order, so that every inequality becomes an
     equation of the form [join x y = y]. *)
  rewrite !leq_lat_leq_lat'_iff !/leq_lat'.
  split.
  - move => H1.
    split.
    (* In both subgoals, substitute [e1] by [join (join e2 e3) e1] (from H1),
       then reassociate and commute the joins so that idempotence of [join]
       closes the goal. *)
    + rewrite -H1
       {1}join_associative
       {1}join_associative
       [join e3 _]join_commutative
       -[join e2 _]join_associative
          join_idempotent
      //.
    + rewrite -H1
       -join_associative
       [join e3 _]join_commutative
       [join _ e3]join_associative
       join_idempotent
      //.
  - move => [H1 H2].
    (* Reassociate to [join e2 (join e3 e1)], then use H2 and H1 in turn. *)
    rewrite join_associative H2 H1 //.
Qed.

(* Being below either argument of a join suffices to be below the join.
   The converse does not hold in a general lattice. For example, in the
   lattice of subsets of {a, b}, {a, b} is below {a} ∪ {b} but below neither
   {a} nor {b}. This is why [splitLeq] keeps the original inequality as a
   fallback disjunct in its Join-on-the-right case. *)
Lemma leq_join_prime {A : Set} {_:Lattice A} (e1 e2 e3 : A) :
  leq_lat e1 e2 \/ leq_lat e1 e3 -> leq_lat e1 (join e2 e3).
Proof.
  rewrite !leq_lat_leq_lat'_iff !/leq_lat'.
  sauto lq: on use: join_associative.
Qed.

(* Dually, a meet is below [e3] as soon as one of its arguments is. Again
   only this direction holds in general. For example, in the subsets of
   {a, b}, {a} ∩ {b} = ∅ is below ∅, but neither {a} nor {b} is. *)
Lemma leq_meet_prime {A : Set} {_:Lattice A} (e1 e2 e3 : A) :
  leq_lat e1 e3 \/ leq_lat e2 e3 -> leq_lat (meet e1 e2) e3.
Proof.
  hfcrush l: on q: on use: meet_associative, meet_commutative.
Qed.

(* Hint databases used by the [hauto] calls in the soundness and
   completeness proofs.
   - [lat_db_rew]: the two iff characterizations, oriented left to right, as
     rewrite rules.
   - [lat_db]: the one-directional [leq_join_prime] / [leq_meet_prime] as
     resolve hints.
   The rewrite hints must be given in fully explicit form ([@name]) for these
   typeclass-polymorphic lemmas. NOTE(review): the reason this is needed is
   not understood; revisit if the hints misbehave. *)
#[export] Hint Rewrite -> @leq_meet_iff @leq_join_iff : lat_db_rew.
#[export] Hint Resolve leq_join_prime leq_meet_prime : lat_db.

(* Soundness of the goal transformation: to prove an inequality between
   denotations, it suffices to prove [splitLeq] on the expressions.
   The proof inducts over the Equations-generated graph of [splitLeq].
   [remember] abstracts [splitLeq e1 e2] as [p] for that induction. Each case
   is closed by [hauto] with the [lat_db] hints and [lat_db_rew] rewrites. *)
Theorem splitLeq_sound {A : Set} {H:Lattice A} (e1 e2 : lexp A) :
  splitLeq e1 e2 -> leq_lat (denoteLexp e1) (denoteLexp e2).
Proof.
  intros.
  have h0 := splitLeq_graph_correct _ H e1 e2.
  remember (splitLeq e1 e2) as p.
  induction h0 using splitLeq_graph_rect;
    hauto lq: on rew: off db: lat_db rew:db: lat_db_rew.
Qed.

(* Completeness of the goal transformation, the converse of [splitLeq_sound].
   It holds because the Join-on-the-right and Meet-on-the-left clauses of
   [splitLeq] keep the original inequality as a fallback disjunct. The proof
   has the same structure as [splitLeq_sound]. *)
Theorem splitLeq_complete {A : Set} {H:Lattice A} (e1 e2 : lexp A) :
  leq_lat (denoteLexp e1) (denoteLexp e2) -> splitLeq e1 e2.
Proof.
  intros.
  have h0 := splitLeq_graph_correct _ H e1 e2.
  remember (splitLeq e1 e2) as p.
  induction h0 using splitLeq_graph_rect;
    hauto lq: on rew: off db: lat_db rew:db: lat_db_rew.
Qed.

(* Transforming hypotheses: an inequality between denotations implies its
   [splitLeqForward] decomposition. Used by the Ltac2 hypothesis simplifier
   below. The proof is by induction over the Equations-generated graph of
   [splitLeqForward]. *)
Theorem splitLeqForward_complete {A : Set} {H:Lattice A} (e1 e2 : lexp A) :
  leq_lat (denoteLexp e1) (denoteLexp e2) -> splitLeqForward e1 e2.
Proof.
  move => H0.
  have h0 := splitLeqForward_graph_correct _ H e1 e2.
  remember (splitLeqForward e1 e2) as p.
  induction h0 using splitLeqForward_graph_rect;
    hauto db: lat_db rew:db:lat_db_rew.
Qed.

(* [splitLeq] exactly characterizes the order between denotations; this
   packages [splitLeq_complete] (left to right) and [splitLeq_sound]
   (right to left). *)
Theorem splitLeq_iff {A : Set} {H:Lattice A} (e1 e2 : lexp A) :
  leq_lat (denoteLexp e1) (denoteLexp e2) <-> splitLeq e1 e2.
Proof.
  split.
  - exact (splitLeq_complete e1 e2).
  - exact (splitLeq_sound e1 e2).
Qed.

(* Reify a lattice term into an [lexp]. Applications of [meet] / [join] become
   [Meet] / [Join] nodes, with the left argument reified before the right.
   Any other subterm becomes an atom [Var]. Matching is syntactic, so
   [denoteLexp] of the result computes back to the input term. *)
Ltac2 rec reify_lexp (e : constr) :=
  lazy_match! e with
  | meet ?l ?r =>
    let rl := reify_lexp l in
    let rr := reify_lexp r in
    '(Meet $rl $rr)
  | join ?l ?r =>
    let rl := reify_lexp l in
    let rr := reify_lexp r in
    '(Join $rl $rr)
  | _ => '(Var $e)
  end.

(* Simplify hypothesis [id] for the lattice solver. [ty] is its type as
   reported by the caller, and matching is done on [ty] as given.
   - If [ty] is [leq_lat a1 a2], both sides are reified. The hypothesis is
     then replaced, via [splitLeqForward_complete], by its [splitLeqForward]
     decomposition, which is unfolded with [simp splitLeqForward] and
     [simpl].
   - Otherwise the hypothesis is not relevant to the lattice reasoning and is
     erased. Hypotheses that the goal or other hypotheses still depend on
     (e.g. the carrier type, the [Lattice] instance, or lattice elements)
     cannot be cleared. Those are kept rather than making the whole tactic
     fail. *)
Ltac2 simplify_lattice_hyp (id : ident) (ty : constr) : unit :=
  simpl in $id;
  lazy_match! ty with
  | leq_lat ?a1 ?a2 =>
      let e1 := reify_lexp a1 in
      let e2 := reify_lexp a2 in
      apply (splitLeqForward_complete $e1 $e2) in $id;
      ltac1:(h1 |- simp splitLeqForward in h1) (Ltac1.of_ident id);
      simpl in $id
  (* TODO: keep the equalities about lattices *)
  | _ =>
      (* Pass the ident value to [Std.clear] directly. In the [clear]
         notation an unquoted [id] would name a hypothesis literally called
         [id]. A failure (dependent hypothesis) is ignored. *)
      Control.plus (fun () => Std.clear [id]) (fun _ => ())
  end.

(* Apply [simplify_lattice_hyp] to every hypothesis of the current goal, in
   the order returned by [Control.hyps]. The list is taken once, up front,
   before any hypothesis is rewritten or cleared. *)
Ltac2 simplify_lattice_hyps () : unit :=
  let hyps := Control.hyps () in
  List.iter
    (fun h => match h with | (id, _, ty) => simplify_lattice_hyp id ty end)
    hyps.

(* Prepare the goal for the lattice solver. After [simpl] and [intros]:
   - a goal [leq_lat x y] is reified. Via [splitLeq_sound] it is replaced by
     [splitLeq] on the reified sides, which [simp splitLeq] then unfolds;
   - any other goal is turned into [False] with [exfalso]. *)
Ltac2 simplify_lattice_goal () : unit :=
  simpl;
  intros;
  lazy_match! goal with
  | [|- leq_lat ?x ?y] =>
    let lhs := reify_lexp x in
    let rhs := reify_lexp y in
    apply (splitLeq_sound $lhs $rhs);
    ltac1:(simp splitLeq)
  | [|- _] => ltac1:(exfalso)
  end.

(* Entry point of the lattice simplifier. It simplifies the goal first, then
   all hypotheses, including any that [intros] introduced during the goal
   step. It only transforms the proof state and does not close the goal by
   itself.
   TODO: parameterize solve_lattice by a base case tactic for handling the
   leaves? *)
Ltac2 solve_lattice () : unit :=
  simplify_lattice_goal ();
  simplify_lattice_hyps ().

Ltac2 Notation "solve_lattice" := solve_lattice ().
