{-# LANGUAGE NoMonomorphismRestriction #-}
module ProductMixAuction.FujishigeWolfe
  ( BitMask
  , SubmodularFunction
  , findMin
  , fromBitMask
    -- primarily for testing
  , edmondsGreedy
  , mkPoint
  , submodularityCheck
  )
  where

import Data.List (foldl', sortBy)
import Data.Ord (comparing)
import qualified Data.Vector as Vec
import qualified Data.Vector.Storable as V
import Numeric.LinearAlgebra
-- import Debug.Trace

-- | A point in real space, one coordinate per index of the problem.
type Point = Vector Double

-- | Build a 'Point' from its coordinates, given in index order.
mkPoint :: [Double] -> Point
mkPoint coords = vector coords

-- | This is intended to be the main entry point: minimise a
-- submodular function.
--
-- It takes the dimension of the problem (i.e., the number of goods
-- we are considering) and the submodular function itself.
--
-- It returns a 'BitMask' of length @n@ describing a set on which
-- @sf@ is (approximately) minimal. The set is read off the minimum
-- norm point @p@ computed by 'minNormPoint': index @i@ is included
-- exactly when @p_i < -1e-8@. For submodular @sf@, the strictly
-- negative coordinates of the exact minimum norm point form the
-- inclusion-wise smallest minimiser (Fujishige); the tolerance
-- absorbs rounding noise around zero.
--
-- Edmonds' greedy algorithm provides both the initial vertex (for
-- the all-ones weight vector) and the linear optimisation oracle.
--
-- Normalisation: @sf@ is only ever used here (via 'edmondsGreedy')
-- through differences @sf T - sf S@ of nested sets, so adding a
-- constant to @sf@ does not change the result. Requiring @sf@ to be
-- 0 on the empty set is therefore not necessary for this function.
-- NOTE(review): confirm before dropping any external normalisation.
--
findMin :: Int -> SubmodularFunction -> BitMask
findMin n sf =
  Vec.generate n (\ i -> p V.! i < -epsilon)
  where
    -- Greedy vertex for a given weight vector.
    oracle :: Point -> Point
    oracle = edmondsGreedy sf

    -- Any vertex works as a starting point; we use the greedy vertex
    -- for the all-ones weights.
    p :: Point
    p = minNormPoint (oracle (V.replicate n 1)) oracle

    epsilon :: Double
    epsilon = 1e-8 -- relatively tolerant


-- | Find the minimal norm point of the given polytope.
--
-- The function takes an initial point that is a vertex of the
-- polytope.
--
-- We also take a linear optimization oracle for the polytope. Given
-- the current point x, it is expected to return a point j of the
-- polytope minimising the inner product x . j (Step 1 relies on this).
-- This is the only info (next to the initial vertex) that we
-- actually need about the given polytope.
--
-- The step numbers and identifiers in the algorithm follow the
-- original Wolfe description.
--
-- The loop ends either when the optimality test of Step 1c succeeds,
-- or when the oracle returns a point that is already in the current
-- set (Step 1d). There is no iteration limit.
--
minNormPoint ::
     Point            -- ^ initial point
  -> (Point -> Point) -- ^ linear optimization oracle for polytope
  -> Point
minNormPoint initial linOpt =

  -- Step 0. Initialisation.
  --
  -- The algorithm maintains a set of points S on the polytope,
  -- and a list of weights w, one per point of S (the coefficients
  -- of a combination of these points).
  --
  -- Computing the linear combination induced by w and S yields
  -- the current point X.

  let
    w = [1]
    s = [initial]
    x = initial
  in
    step1 w s x

  where
    -- Step 1 ("major cycle"): ask the oracle for a better point and
    -- either stop or add it to S.
    step1 :: [Double] -> [Point] -> Point -> Point
    -- step1 w s x | traceShow (1, w, s, x) False = undefined
    step1 w s x =
      let
        -- Step 1a. We don't have to compute x because we
        -- already have it.
        --
        -- Step 1b. We compute J via the linear optimisation
        -- oracle.
        --
        j :: Point
        j = linOpt x
      in
        -- Step 1c. We check if we've reached optimality: if even the
        -- oracle's point j does not lower x . j below x . x by more
        -- than z1, we stop.
        --
        -- NOTE(review): Wolfe scales Z1 by the largest squared norm of
        -- the points in S; here z1 is an absolute tolerance, which may
        -- be too strict when the function values are large. Confirm
        -- this is intended.
        if dot x j > dot x x - z1
          then x -- We're done.
          else
            if j `elem` s
              -- Step 1d. We check if the point is already in our
              -- set. Note that `elem` compares points with exact
              -- floating point equality.
              --
              -- Wolfe says that this typically indicates that the
              -- Step 1c has almost succeeded.
              --
              then x -- emergency stop
              else
                -- Step 1e. We initialise Step 2 by updating S and
                -- w with the new point, which starts with weight 0.
                --
                let
                  s' = s ++ [j]
                  w' = w ++ [0]
                in
                  step2 w' s' x

    -- Step 2 ("minor cycle"): move to the affine minimum of S if it
    -- lies inside the convex hull of S, otherwise go to Step 3.
    step2 :: [Double] -> [Point] -> Point -> Point
    -- step2 w s x | traceShow (2, w, s, x) False = undefined
    step2 w s x =
      let
        -- Step 2a. This is where we find the affine minimum.
        -- We currently do so naively.
        --
        -- Our affineMin function returns both the coefficients
        -- v and the point y. The Wolfe paper only talks about v.
        --
        y :: Point
        v :: [Double]
        (y, v) = affineMin s
      in
        -- Step 2b. We check if all the coefficients in v are
        -- positive (greater than z2). If so, we can replace w by v
        -- and go back to Step 1, because y is in fact in the convex
        -- hull of our current set s. If not, we go to Step 3.
        --
        if all (> z2) v
          then
            let
              w' = v
              x' = y
            in
              step1 w' s x'
          else
            -- We keep all relevant data around for Step 3.
            step3 w s x v y

    -- Step 3: move from w towards v as far as the weights stay
    -- non-negative, drop the points whose weight became (nearly) 0,
    -- and retry Step 2 with the smaller set.
    step3 :: [Double] -> [Point] -> Point -> [Double] -> Point -> Point
    -- step3 w s x v y | traceShow (3, w, s, x, v, y) False = undefined
    step3 w s _x v _y =
      -- Note that above, we currently discard points x and y. We
      -- could compute the new point via theta below and the old
      -- points. However, we decide to compute the new weights using
      -- theta and simply recompute the point from these.
      let
        -- Step 3a. According to the Wolfe paper, we compute the set
        -- POS of indices for which w - v is positive.
        --
        -- Note that the coefficients in w are all non-negative.
        -- And because we're here, we know at least one of the v
        -- coefficients is not greater than z2, so at least one of
        -- the w - v values is (nearly) at least w.
        --
        -- We ultimately want the minimal (w / (w - v)), and in
        -- fact, we want it to be at most 1.
        --
        -- The filter below additionally requires v <= z2. Indices
        -- with v > z2 > 0 have w / (w - v) > 1, so excluding them does
        -- not change theta, which is capped at 1 anyway.

        wvPairs :: [(Double, Double)]
        wvPairs = zip w v

        wPosPairs :: [(Double, Double)]
        wPosPairs = filter (\ (w_, v_) -> v_ <= z2 && w_ - v_ > z3) wvPairs

        -- Step 3b. Compute theta which should be between 0 and 1.
        --
        theta :: Double
        theta = minimum (1 : map (\ (w_, v_) -> w_ / (w_ - v_)) wPosPairs)

        -- Step 3c. Compute a new w, i.e. (1 - theta) * w + theta * v.
        --
        -- I think the order of w and v is erroneously swapped in the
        -- Wolfe paper here. I've applied the order given in the
        -- Chakrabarty paper.
        --
        w' :: [Double]
        w' = map (\ (w_, v_) -> theta * v_ + (1 - theta) * w_) wvPairs

        -- Step 3d. Replace by 0 all sufficiently small elements in w'.
        -- Step 3e. Remove all points in S corresponding to 0 coefficients.
        --
        -- TODO: Note that the Wolfe paper says we should remove just one
        -- element from S. It's not entirely clear if removing more is
        -- harmful, and could e.g. affect termination.
        --
        w'' :: [Double]
        s' :: [Point]
        (w'', s') = unzip (filter (\ (w_, _s_) -> w_ > z2) (zip w' s))

        -- step2 only forwards its point argument to step3, which
        -- ignores it, so under lazy evaluation x' is never computed.
        x' :: Point
        x' = linComb w'' s'
      in
        -- traceShow ("theta", theta, w', w'') $
        step2 w'' s' x'

    -- Thresholds as suggested in the Wolfe paper.
    --
    -- z1: optimality tolerance of Step 1c.
    z1 :: Double
    z1 = 1e-5 -- 1e-12

    -- z2: a coefficient counts as positive only above this value
    -- (Steps 2b, 3a and 3d).
    z2 :: Double
    z2 = 1e-10

    -- z3: minimal w - v for an index to take part in the theta
    -- computation of Step 3a (avoids dividing by ~0).
    z3 :: Double
    z3 = 1e-10

-- | Weighted sum of points: @linComb [c1, .., ck] [p1, .., pk]@ is
-- @c1 * p1 + .. + ck * pk@, accumulated strictly from left to right
-- starting at the literal 0. The two lists are zipped, so surplus
-- coefficients or points are ignored.
--
-- NOTE(review): for empty lists the result is hmatrix's numeric
-- literal 0, not a zero vector of the points' dimension; callers
-- are expected to pass at least one point.
linComb :: [Double] -> [Point] -> Point
linComb cs ps =
  foldl' addTerm 0 (zip cs ps)
  where
    addTerm acc (c, p) = acc + scale c p

-- | A subset of the ground set (the goods), as one flag per index;
-- 'True' means the index is in the set.
type BitMask = Vec.Vector Bool

-- | An integer-valued set function on 'BitMask's. 'findMin' expects it
-- to be submodular; 'submodularityCheck' can verify this by brute force.
type SubmodularFunction = BitMask -> Integer

-- | Edmonds' greedy algorithm. This is intended to be used as the
-- linear optimisation oracle in the main algorithm.
--
-- It is for example described in the thesis by Garcia,
-- as Algorithm 2.
--
-- Given weights @w@, it orders the indices by ascending weight as
-- @perm[0], perm[1], ...@ and returns the point @q@ with
--
-- > q[perm[k]] = sf {perm[0], .., perm[k]} - sf {perm[0], .., perm[k-1]}
--
-- For submodular @sf@ this is a vertex of the base polytope that
-- minimises the inner product with @w@. Only differences of @sf@
-- values on nested sets are used, so adding a constant to @sf@ does
-- not change the result.
--
-- The algorithm can also be used to obtain an initial
-- point for the main algorithm.
--
edmondsGreedy :: SubmodularFunction -> Point -> Point
edmondsGreedy sf w =
  let
    n = V.length w
    -- Indices of w in order of ascending weights.
    perm = sortIndex w
    -- The chain of sets formed by the prefixes of perm, from the
    -- empty set up to the full set, as bit masks.
    masks :: [BitMask]
    masks = map (toBitMask n) (vecInits perm)
    -- Each set of the chain is evaluated exactly once: the list is
    -- shared, so an inner set is not re-evaluated for both adjacent
    -- pairs it belongs to. This halves the calls to a possibly
    -- expensive sf.
    values :: [Integer]
    values = map sf masks
    -- Marginal gains in perm order: entry k belongs to index perm[k].
    gains :: [Double]
    gains =
      [ fromIntegral (after - before)
      | (before, after) <- adjacentPairs values
      ]
  in
    -- Scatter the gains back to their original positions, so that
    -- result[perm[k]] = gains[k]. This is a forward permutation (the
    -- inverse of backpermute): sort the (index, gain) pairs by index.
    vector (map snd (sortBy (comparing fst) (zip (toList perm) gains)))
      -- cmap (\ i -> p ! fromIntegral i) perm

-- | Compute the affine minimum of a list of points: the point of
-- minimal Euclidean norm in the affine hull of @s@, together with
-- the affine coefficients (summing to 1) expressing it in terms of
-- the points of @s@, in the same order.
--
-- This is done in a very naive and inefficient way, by explicitly
-- inverting a bordered Gram matrix. There are various ways of
-- improving this computation, some described by Wolfe, others being
-- discussed in the thesis by Garcia. Solving the linear system
-- instead of inverting would also be numerically preferable.
--
-- Derivation. Let B be the matrix whose columns are the points. We
-- minimise |B a|^2 subject to sum a = 1. The optimality conditions
-- are
--
-- >  lambda + (B^T B a)_i = 0    for every i
-- >  sum a                = 1
--
-- i.e. M [lambda; a] = [0, .., 0, 1]^T for the matrix M built below,
-- so [lambda; a] is the last column of M^-1. M is a row permutation
-- of the symmetric matrix [[0, 1^T], [1, B^T B]]; by that symmetry
-- the first row of M^-1 is [a_1, .., a_k, lambda], which is what we
-- read the coefficients from.
--
-- If the points are affinely dependent, M is singular and the result
-- is not meaningful.
--
affineMin :: [Point] -> (Point, [Double])
affineMin s =
  let
    -- Sizes: let n be the dimension of a point, and k be the number
    -- of points.
    --
    -- n x k: the points as columns.
    b :: Matrix Double
    b = fromColumns s

    -- k x k: the Gram matrix.
    btb :: Matrix Double
    btb = tr b <> b

    -- (k + 1) x (k + 1). fromBlocks expands the scalar blocks: a
    -- k x 1 column of ones, a 1 x 1 zero, and a 1 x k row of ones.
    m :: Matrix Double
    m = fromBlocks [ [ 1, btb], [0, 1] ]

    e :: Matrix Double
    e = inv m

    -- k x 1: the affine coefficients (first row of e without its
    -- last entry, which is the multiplier lambda).
    alpha :: Matrix Double
    alpha = tr (e ?? (Take 1, DropLast 1))

    -- n x 1: the affine minimum itself.
    y :: Matrix Double
    y = b <> alpha
  in
    -- Both matrices have a single column, so flattening them yields
    -- that column (and, unlike head . toColumns, is total).
    (flatten y, toList (flatten alpha))

-- | All prefixes of a vector, from the empty one up to the whole
-- vector, in order of increasing length.
vecInits :: V.Storable a => Vector a -> [Vector a]
vecInits v =
  [ V.take len v | len <- [0 .. V.length v] ]

-- | Turn a vector of indices into a 'BitMask' of length @n@ in which
-- exactly the given indices are set. Indices outside @[0, n)@ are
-- ignored and duplicates are harmless.
--
-- The listed positions are written into an all-False vector in one
-- bulk update, which is O(n + k) for k indices. The previous approach
-- searched the whole index vector for each of the n positions, which
-- is O(n * k).
toBitMask :: Int -> Vector I -> BitMask
toBitMask n x =
  Vec.replicate n False
    Vec.// [ (j, True) | i <- V.toList x, let j = fromIntegral i, j >= 0, j < n ]

-- | Numeric view of a 'BitMask': 1 where a flag is set, 0 elsewhere.
fromBitMask :: Num a => BitMask -> Vec.Vector a
fromBitMask = Vec.map indicator
  where
    indicator True  = 1
    indicator False = 0

-- | Consecutive pairs of a list, e.g.
-- @adjacentPairs [a, b, c] == [(a, b), (b, c)]@.
-- Lists with fewer than two elements yield no pairs.
adjacentPairs :: [a] -> [(a, a)]
adjacentPairs xs = zip xs (drop 1 xs)

-- Utility functions

-- | All 2^n boolean lists of length @n@, in lexicographic order with
-- False before True (so the all-False list comes first).
-- @n@ must be non-negative; a negative argument never reaches the
-- base case.
allSets :: Int -> [[Bool]]
allSets k
  | k == 0    = [[]]
  | otherwise = concatMap (\ b -> map (b :) (allSets (k - 1))) [False, True]

-- | Brute-force submodularity check over all ordered pairs (X, Y) of
-- subsets of an @n@-element ground set:
-- @sf X + sf Y >= sf (X or Y) + sf (X and Y)@, with union and
-- intersection taken pointwise on the flags. This needs 4^n pairs, so
-- it is only suitable for small test instances.
submodularityCheck :: Int -> SubmodularFunction -> Bool
submodularityCheck n sf = all holdsFor pairs
  where
    subsets = allSets n
    pairs   = [ (x, y) | x <- subsets, y <- subsets ]

    evalList = sf . Vec.fromList

    holdsFor (x, y) =
      evalList x + evalList y
        >= evalList (zipWith (||) x y) + evalList (zipWith (&&) x y)

-- Local test cases

-- | Two-good example; its minimum, -2, is attained on the set {1}.
_test :: SubmodularFunction
_test bv = table (bv Vec.! 0) (bv Vec.! 1)
  where
    table False False = 0
    table False True  = -2
    table True  False = 2
    table True  True  = -1

-- | Two-good example; its minimum, -1, is attained on the set {0, 1}.
_test2 :: SubmodularFunction
_test2 bv = table (bv Vec.! 0) (bv Vec.! 1)
  where
    table False False = 0
    table False True  = 0
    table True  False = 0
    table True  True  = -1

-- | Two-good example; its minimum, -2, is attained on the set {0, 1}.
_test3 :: SubmodularFunction
_test3 bv = table (bv Vec.! 0) (bv Vec.! 1)
  where
    table False False = 0
    table False True  = 1
    table True  False = 2
    table True  True  = -2