{-# OPTIONS --cubical --safe #-}

module TreeFold.Indexed where

open import Prelude
open import Data.Binary using (𝔹; 0ᡇ; 1ᡇ_; 2ᡇ_; ⟦_β‡“βŸ§; ⟦_β‡‘βŸ§)
open import Data.Binary.Increment using (inc)
open import Data.Binary.Properties.Isomorphism
open import Data.Nat

private
  variable
    -- Generalisable variables used in signatures below: natural-number
    -- sizes, a universe level, a size-indexed type family, and a 𝔹 index.
    n m : ℕ
    t : Level
    N : ℕ → Type t
    ns : 𝔹

double : β„• β†’ β„•
double n = n * 2

2^_*_ : β„• β†’ β„• β†’ β„•
2^ zero  * m = m
2^ suc n * m = double (2^ n * m)

-- An array whose shape follows the digits of a 𝔹 index (0ᵇ, 1ᵇ_, 2ᵇ_).
-- Each `1∷` cell stores a `T 1` and each `2∷` cell stores a `T 2`.
-- The tail is indexed by `T ∘ double`, so every step down the spine
-- doubles the sizes of the stored elements.
-- NOTE(review): the level variable `a` is not declared in this file's
-- `variable` block; presumably it is provided by Prelude (confirm).
infixr 5 _1∷_ _2∷_
data Array (T : ℕ → Type a) : 𝔹 → Type a where
  []  : Array T 0ᵇ
  _1∷_ : T 1 → Array (T ∘ double) ns → Array T (1ᵇ ns)
  _2∷_ : T 2 → Array (T ∘ double) ns → Array T (2ᵇ ns)

cons : (βˆ€ n β†’ N n β†’ N n β†’ N (double n)) β†’ N 1 β†’ Array N ns β†’ Array N (inc ns)
cons branch x [] = x 1∷ []
cons branch x (y 1∷ ys) = branch 1 x y 2∷ ys
cons branch x (y 2∷ ys) = x 1∷ cons (branch ∘ double) y ys

-- Right fold over an `Array`, producing a value indexed by ⟦ ns ⇓⟧.
-- `c n m x acc` joins an element x of size `2^ n * 1` onto an accumulator of
-- size `2^ n * m`, giving size `2^ n * suc m`.  `b` is the result for `[]`.
--   * a `1∷` cell holds a size-1 element (= 2^ 0 * 1) and is joined at level 0;
--   * a `2∷` cell holds a size-2 element (= 2^ 1 * 1) and is joined at level 1.
-- On recursion the family is precomposed with `double` and the level is
-- shifted by `suc`, so the sizes line up definitionally at every depth.
array-foldr : (N : ℕ → Type t) → (∀ n m → N (2^ n * 1) → N (2^ n * m) → N (2^ n * suc m)) → N 0 → Array N ns → N ⟦ ns ⇓⟧
array-foldr {ns = 0ᵇ}    N c b []        = b
array-foldr {ns = 2ᵇ ns} N c b (x 2∷ xs) = c 1 ⟦ ns ⇓⟧       x (array-foldr (N ∘ double) (c ∘ suc) b xs)
array-foldr {ns = 1ᵇ ns} N c b (x 1∷ xs) = c 0 (⟦ ns ⇓⟧ * 2) x (array-foldr (N ∘ double) (c ∘ suc) b xs)

open import Data.Vec
import Data.Nat.Properties as β„•

double≑*2 : βˆ€ n β†’ n + n ≑ n * 2
double≑*2 zero    = refl
double≑*2 (suc n) = cong suc (β„•.+-suc n n ΝΎ cong suc (double≑*2 n))

-- Tree fold whose combining function is told the level p and both counts:
-- `f p n m` joins values of sizes `2^ p * n` and `2^ p * m` into one of size
-- `2^ p * (n + m)`.  `z` is the value for an empty input.
module NonNorm {t} (N : ℕ → Type t) (f : ∀ p n m → N (2^ p * n) → N (2^ p * m) → N (2^ p * (n + m))) (z : N 0) where
  -- Build the binary spine by consing each element in turn.  Two equal-size
  -- values are joined with `f 0 k k`, then transported from N (k + k) to
  -- N (k * 2) along `double≡*2`.
  spine : Vec (N 1) n → Array N ⟦ n ⇑⟧
  spine []       = []
  spine (v ∷ vs) = cons (λ k l r → subst N (double≡*2 k) (f 0 k k l r)) v (spine vs)

  -- Collapse a spine, always joining one unit (count 1) onto the
  -- accumulator; `1 + m` reduces to `suc m`, matching `array-foldr`.
  unspine : Array N ns → N ⟦ ns ⇓⟧
  unspine = array-foldr N (λ p m → f p 1 m) z

  -- Fold a vector: build its spine, collapse it, and transport the result
  -- from N ⟦ ⟦ n ⇑⟧ ⇓⟧ back to N n along `ℕ→𝔹→ℕ n`.
  treeFold : Vec (N 1) n → N n
  treeFold {n = n} xs = subst N (ℕ→𝔹→ℕ n) (unspine (spine xs))

pow-suc : βˆ€ n m β†’ (2^ n * 1) + (2^ n * m) ≑ (2^ n * suc m)
pow-suc zero m = refl
pow-suc (suc n) m = sym (β„•.+-*-distrib (2^ n * 1) (2^ n * m) 2) ΝΎ cong (_* 2) (pow-suc n m)

-- Tree fold for a combining function that only knows sizes:
-- `f` joins values of sizes n and m into one of size n + m.
-- `z` is the value for an empty input.  Size mismatches introduced by the
-- spine encoding are repaired with `subst`.
module _ {t} (N : ℕ → Type t) (f : ∀ {n m} → N n → N m → N (n + m)) (z : N 0) where
  -- Build the binary spine by consing each element in turn.  Two equal-size
  -- values are joined with `f` and transported from N (k + k) to N (k * 2).
  spine : Vec (N 1) n → Array N ⟦ n ⇑⟧
  spine []       = []
  spine (v ∷ vs) = cons (λ k l r → subst N (double≡*2 k) (f l r)) v (spine vs)

  -- Collapse a spine.  Each join produces `2^ p * 1 + 2^ p * m`, which is
  -- transported to `2^ p * suc m` along `pow-suc`.
  unspine : Array N ns → N ⟦ ns ⇓⟧
  unspine = array-foldr N (λ p m l r → subst N (pow-suc p m) (f l r)) z

  -- Fold a vector: build its spine, collapse it, and transport the result
  -- from N ⟦ ⟦ n ⇑⟧ ⇓⟧ back to N n along `ℕ→𝔹→ℕ n`.
  treeFold : Vec (N 1) n → N n
  treeFold {n = n} xs = subst N (ℕ→𝔹→ℕ n) (unspine (spine xs))