-- | Exhaustive search for candidate auction rates through
--   hyperplane intersections.
module ProductMixAuction.BudgetConstraints.Intersect where

import Control.Lens
import Data.Coerce
import Data.List
import Data.Maybe
import Numeric.LinearAlgebra
import qualified Data.Set as Set
import qualified Data.Vector as V
import qualified Data.Vector as VG
import qualified Data.Vector.Storable as VS
import ProductMixAuction.BudgetConstraints.Types

-- | Return all the candidate auction rates (price vectors) for the
--   given bids.
--
--   The dimension of the search space is the length of the first bid's
--   price vector. NOTE(review): this assumes every bid has a price
--   vector of that same length. With no bids the result is empty.
allCandidateRates :: Eq b => AuctionKind -> [Bid b] -> Set.Set PriceVector
allCandidateRates _ [] = mempty
allCandidateRates ak bids@(firstBid : _) =
  Set.fromList [ coerce (toBoxed pt) | pt <- allIntersections ak dim goodPrices ]
  where
    dim = V.length (firstBid ^. bid_prices)

    -- One (good index, price, bid label) triple per good of every bid,
    -- bids in input order, goods in index order.
    goodPrices = concatMap goodPricesOf bids

    goodPricesOf bid =
      [ P i (_Price p) lbl
      | (i, p) <- zip [0 ..] (V.toList (bid ^. bid_prices))
      ]
      where lbl = bid ^. bid_label

    -- hmatrix solutions are storable vectors; PriceVector is built from
    -- a boxed vector via 'coerce'.
    toBoxed :: VS.Vector Double -> V.Vector Double
    toBoxed = VG.convert

-- | Given all the @(i, p_i)@ pairs of good number and price, this:
--
--   * builds the hyperplanes they generate (both hods and flanges);
--   * solves every linear system obtained by choosing @n@ of them that
--     includes at least one hod;
--   * post-processes the solutions with 'roundPoints' for the given
--     auction kind.
allIntersections :: Eq b => AuctionKind -> Int -> [GoodPrice b] -> [Vector Double]
allIntersections ak n goodPrices = roundPoints ak solutions
  where
    hyperplanes = allHyperplanes goodPrices
    systems     = allIntersectionSystems ak n hyperplanes
    solutions   = solveSystems systems

-- | Return all the hyperplanes generated by the given list of
--   @(i, p_i)@ pairs, both hods and flanges.
--
--   * One 'Hod' per pair.
--   * One 'Flange' per unordered pair of distinct goods that share a bid
--     label. The lower good index always comes first, i.e. @i < j@.
--
--   @Flange (P i ..) (P j ..)@ and @Flange (P j ..) (P i ..)@ describe
--   the same hyperplane; only the sign of the equation differs. Emitting
--   both would only add redundant combinations to the intersection
--   search, and systems containing both would always be singular.
--
--   The resulting list is then used to pick @n@ hyperplanes at a time
--   and consider their intersection by solving the corresponding linear
--   system.
allHyperplanes :: Eq b => [GoodPrice b] -> [HyperplaneType b]
allHyperplanes goodprices = hods ++ flanges
  where
    hods = map Hod goodprices

    flanges =
      [ Flange gi gj
      | gi@(P i _ bid_i) <- goodprices
      , gj@(P j _ bid_j) <- goodprices
      , bid_i == bid_j
      , i < j   -- one orientation per unordered pair; also excludes i == j
      ]

-- | Generate the linear systems @(A, b)@ for every choice, as an
--   unordered subset, of @n@ hyperplanes from the given list.
--
--   Only choices that contain at least one 'Hod' are kept. Each chosen
--   hyperplane is turned into an equation in an @n@-dimensional space
--   with 'toHyperplane'.
allIntersectionSystems :: AuctionKind -> Int -> [HyperplaneType b] -> [(Matrix Double, Vector Double)]
allIntersectionSystems ak n hs =
  map (hyperplanesSystem . map (toHyperplane ak n))
      (filter (any isHod) (select n hs))
  where
    isHod (Hod _) = True
    isHod _       = False

-- | Solve each square system @Ax = b@ with '(<\>)'.
--
--   Only systems whose matrix has full rank, and therefore a unique
--   solution, are kept. Rank-deficient systems are silently dropped.
--   Calls 'error' if a matrix is not square.
solveSystems :: [(Matrix Double, Vector Double)] -> [Vector Double]
solveSystems = mapMaybe (uncurry solve)
  where
    solve mat vec
      | nRows /= nCols =
          error $ "solveSystems.solve: non square matrix " ++ show (size mat)
      | rank mat < nRows =
          -- singular matrix => no unique solution
          Nothing
      | otherwise = Just (mat <\> vec)
      where (nRows, nCols) = size mat

-- | Post-process the intersection points for the given auction kind.
--
-- For standard auctions, all intersections should be non-negative
-- integers. Every coordinate is rounded to the nearest integer with
-- 'round' (ties go to the even integer), and points that still have a
-- negative coordinate are dropped.
--
-- For budget-constrained auctions, near-duplicate points are removed
-- first ('removeSimilarPoints'). Then coordinates with absolute value
-- below 'epsilon' are replaced by zero ('approxZeros').
roundPoints :: AuctionKind -> [Vector Double] -> [Vector Double]
roundPoints Standard = filter nonNegative . map (VS.map toWhole)
  where
    toWhole x = fromInteger (round x)
    nonNegative v = VS.all (>= 0) v
roundPoints BudgetConstrained = approxZeros . removeSimilarPoints

-- | Snap every coordinate whose absolute value is below 'epsilon' to an
--   exact zero; all other coordinates are left untouched.
approxZeros :: [Vector Double] -> [Vector Double]
approxZeros = map (VS.map snap)
  where
    snap x
      | abs x < epsilon = 0
      | otherwise       = x

-- | Rounding error may produce several points that are very close
-- together, so keep only one point from each such group.
--
-- The points are sorted first. A point is dropped when its Euclidean
-- distance to an already kept point is below 'epsilon'. This is a
-- quadratic pairwise scan.
removeSimilarPoints :: [Vector Double] -> [Vector Double]
removeSimilarPoints pts = nubBy closeTo (sort pts)
  where
    closeTo p q = norm_2 (p - q) < epsilon

-- | @select n xs@ returns every combination of @n@ elements of @xs@.
--
--   Each subset is returned once, not once per permutation: it returns
--   @[1, 2]@ but not also @[2, 1]@. This is all we need, because the
--   order of the equations in a linear system does not change its
--   solutions.
--
--   Elements keep their relative order from @xs@, and combinations are
--   listed in lexicographic order of positions. A negative @n@ yields
--   no combinations.
select :: Int -> [a] -> [[a]]
select 0 _        = [[]]
select _ []       = []
select k (y : ys) = map (y :) (select (k - 1) ys) ++ select k ys

-- | @'H' as b@ corresponds to the hyperplane with equation
--   @a_0*x_0 + ... + a_(n-1)*x_(n-1) = b@, where the vector @as@ holds
--   the coefficients @a_0 .. a_(n-1)@ and @b@ is the right-hand side.
data Hyperplane = H (Vector Double) Double

-- | Pretty-print a list of hyperplanes, one equation per line, with no
--   trailing newline.
ppHyperplanes :: [Hyperplane] -> String
ppHyperplanes hs = intercalate "\n" [ ppHyperplane h | h <- hs ]

-- | Pretty-print a hyperplane as @c_i*x_i + ... = b@.
--
--   Terms with a zero coefficient are omitted. A coefficient of exactly
--   @1.0@ is printed as the bare variable name. Other coefficients are
--   printed with 'show', so negative ones appear as e.g. @-1.0*x_j@.
ppHyperplane :: Hyperplane -> String
ppHyperplane (H as b) = lhs ++ " = " ++ show b
  where
    lhs = intercalate " + "
      [ term a_i ("x_" ++ show i)
      | (i, a_i) <- zip [0 :: Int ..] (toList as)
      , a_i /= 0
      ]

    term coeff var
      | coeff == 1.0 = var
      | otherwise    = show coeff ++ "*" ++ var

-- | Return @(A, b)@, where @Ax = b@ is the linear system whose solutions
--   are the points lying on all the given hyperplanes.
--
--   Row @k@ of @A@ holds the coefficients of the @k@-th hyperplane, and
--   entry @k@ of @b@ holds its right-hand side.
hyperplanesSystem :: [Hyperplane] -> (Matrix Double, Vector Double)
hyperplanesSystem hs = (fromRows coeffRows, fromList rhs)
  where
    (coeffRows, rhs) = unzip [ (as, b) | H as b <- hs ]

-- | A @(i, p_i)@ pair of good index and price, tagged with the label of
--   the bid it was taken from.
--
--   'allHyperplanes' uses the label to decide which goods can be paired
--   into a flange (only goods with equal labels are).
data GoodPrice b = P !Int !Double b
  deriving (Eq, Show)

-- | Symbolic description of a hyperplane; 'toHyperplane' turns it into
--   a concrete equation.
--
--   @'Hod' (P i p _)@ corresponds to the hyperplane with equation
--   @x_i = p@.
--
--   @'Flange' (P i p_i _) (P j p_j _)@ corresponds to the hyperplane
--   with equation @x_i/p_i = x_j/p_j@ (for budget-constrained auctions),
--   or @x_i-p_i = x_j-p_j@ (for standard auctions). 'toHyperplane'
--   calls 'error' when @i == j@.
data HyperplaneType b
  = Hod !(GoodPrice b)
  | Flange !(GoodPrice b) !(GoodPrice b)
  deriving (Eq, Show)

-- | Build the concrete 'Hyperplane' for the given type in a space of
--   dimension @n@.
--
--   * @Hod (P i p _)@ becomes @x_i = p@.
--   * For standard auctions, @Flange (P i p_i _) (P j p_j _)@ becomes
--     @x_i - x_j = p_i - p_j@.
--   * For budget-constrained auctions, the same flange becomes
--     @x_i/p_i - x_j/p_j = 0@.
--
--   Calls 'error' for a flange whose two goods have the same index.
toHyperplane :: AuctionKind -> Int -> HyperplaneType b -> Hyperplane
toHyperplane ak n t = case t of
  Hod (P i p _b) -> H (row [(i, 1)]) p

  Flange (P i p_i _b_i) (P j p_j _b_j)
    | i == j -> error $ "toHyperplane.Flange: i = j = " ++ show i
    | otherwise -> case ak of
        Standard          -> H (row [(i, 1), (j, -1)]) (p_i - p_j)
        BudgetConstrained -> H (row [(i, 1 / p_i), (j, -1 / p_j)]) 0
  where
    -- Dense coefficient vector of length n: listed entries at their
    -- index, zero everywhere else.
    row entries = fromList [ fromMaybe 0 (lookup k entries) | k <- [0 .. n - 1] ]

-- | Tolerance for floating-point comparisons, computed as @10 ** (-6)@.
--   Tests of the form @< epsilon@ are used in place of @== 0@.
epsilon :: RealFloat a => a
epsilon = 10 ** negate 6