How could I use tail call optimisation on this combination function?


My exercise, which you can see here, asks me to implement a recursive version of C(n, k).

That's my solution:

module LE1.Recursao.Combinacao where

combina :: Integral a => a -> a -> a
combina n k | k == 0         = 1
            | n == k         = 1
combina n k | k > 0 && n > k = (combina (n - 1) k) + (combina (n - 1) (k - 1))
combina _ _                  = 0

Now I want to create a tail-recursive version of this function, so that it doesn't overflow the stack for large inputs and also computes the combination faster.

I'm new to tail call optimisation, but I have used it in Elixir for the Fibonacci series:

defmodule Fibonacci do
  def run(n) when n < 0, do: :error
  def run(n), do: run n, 1, 0
  def run(0, _, result), do: result
  def run(n, next, result), do: run n - 1, next + result, next
end

I understand this Fibonacci code and I think that the combination algorithm isn't too different, but I don't know how to start.


0
dfeuer On BEST ANSWER

Daniel Wagner writes

combina n k = product [n-k+1 .. n] `div` product [1 .. k]

This can get rather inefficient for large n and medium-size k; the multiplications get huge. How might we keep them smaller? We can write a simple recurrence:

combina :: Integer -> Integer -> Integer
combina n k
  | k > n || k < 0 = 0
  | otherwise = combina' n k'
  where
    -- C(n,k) and C(n,n-k) are the same,
    -- so we choose the cheaper one.
    k' = min k (n-k)

-- Assumes 0 <= k <= n
combina' _n 0 = 1
combina' n k
  = -- as above
  product [n-k+1 .. n] `div` product [1 .. k]
  = -- expanding a bit
  (n-k+1) * product [n-k+2 .. n] `div` (product [1 .. k-1] * k)
  = -- Rearranging, and temporarily using real division
  ((n-k+1)/k) * (product [n-(k-1)+1 .. n] / product [1 .. k-1])
  = -- Daniel's definition
  ((n-k+1)/k) * combina' n (k-1)
  = -- Rearranging again to go back to integer division
  ((n-k+1) * combina' n (k-1)) `quot` k

Putting that all together,

combina' _n 0 = 1
combina' n k = ((n-k+1) * combina' n (k-1)) `quot` k

Just one problem remains: this definition is not tail recursive. We can fix that by counting up instead of down:

combina' n k0 = go 1 1
  where
    go acc k
      | k > k0 = n `seq` acc
      | otherwise = go ((n-k+1)*acc `quot` k) (k + 1)

Don't worry about the n `seq`; it only ensures that n is evaluated even when k0 is 0, and it's of very little consequence.

Note that this implementation uses O(min(k,n-k)) arithmetic operations. So if k and n-k are both very large, it will take a long time. I don't know if there's any efficient way to get exact results in that situation; I believe that binomial coefficients of that sort are usually estimated rather than calculated precisely.
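For reference, the pieces above can be combined into one self-contained definition. This is only a sketch: the name combinaTail and the explicit `seq` on the accumulator are my additions for illustration, not part of the definitions above. Forcing the accumulator at every step keeps unevaluated multiplications from piling up while the loop runs.

```haskell
-- Binomial coefficient via the multiplicative recurrence
--   C(n, i) = C(n, i-1) * (n-i+1) / i
-- counting i upward from 1. (combinaTail is a hypothetical name.)
combinaTail :: Integer -> Integer -> Integer
combinaTail n k
  | k < 0 || k > n = 0
  | otherwise      = go 1 1
  where
    -- C(n, k) == C(n, n-k), so loop over the smaller of the two.
    k0 = min k (n - k)

    -- On entry, acc holds C(n, i-1). The product (n-i+1) * C(n, i-1)
    -- equals i * C(n, i), so the integer division is always exact.
    go :: Integer -> Integer -> Integer
    go acc i
      | i > k0    = acc
      | otherwise =
          let acc' = (n - i + 1) * acc `quot` i
          in  acc' `seq` go acc' (i + 1)  -- force acc' before the tail call
```

With this definition, combinaTail 40 20 evaluates to 137846528820, and out-of-range arguments such as combinaTail 3 5 give 0.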

7
Daniel Wagner On

The traditional implementation looks like this:

combina n k = product [n-k+1 .. n] `div` product [1 .. k]

It runs instantly for just about as big a number as you care to put in, and fairly quickly for numbers bigger than you care to put in.

> :set +s
> combina 40000 20000
<answer snipped because it's way too big to be useful>
(0.74 secs, 837,904,032 bytes)

Meanwhile for your implementation, combina 30 15 took 7s and combina 40 20 took longer than I cared to wait -- multiple minutes, at least, even compiled with optimizations rather than interpreted.

It is possible to get still more efficient than this by choosing clever orders of multiplications and divisions, but it's significantly less readable.
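One possible reading of "clever orders of multiplications" is to multiply the range as a balanced tree instead of left to right, so that factors of similar size are multiplied together. That interpretation is my assumption, and the names rangeProduct and combinaBalanced below are hypothetical; treat this as a sketch rather than a measured improvement.

```haskell
-- Product of all integers in [lo .. hi], computed by splitting the range
-- in half so that the two factors of each multiplication have similar size.
-- (rangeProduct is a hypothetical helper name.)
rangeProduct :: Integer -> Integer -> Integer
rangeProduct lo hi
  | lo > hi   = 1                 -- empty range
  | lo == hi  = lo                -- single element
  | otherwise = rangeProduct lo mid * rangeProduct (mid + 1) hi
  where
    mid = (lo + hi) `div` 2       -- lo <= mid < hi, so both halves shrink

-- Same formula as the one-liner above, with the balanced product swapped in
-- and the symmetry C(n, k) == C(n, n-k) used to shorten both ranges.
combinaBalanced :: Integer -> Integer -> Integer
combinaBalanced n k
  | k < 0 || k > n = 0
  | otherwise      = rangeProduct (n - k' + 1) n `div` rangeProduct 1 k'
  where
    k' = min k (n - k)
```

This version is not tail recursive, but the recursion depth is only logarithmic in the length of the range, so stack use should stay small.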

2
dfeuer On

Daniel Wagner's answer (or my other answer, or similar things) is the right way in practice. But if you want to use the recurrence described in the problem set and still have your function run in a reasonable amount of time, you'll have to structure the solution rather differently. Consider what happens for some n and k not near a base case.

combina n k
  = combina (n - 1) k + combina (n - 1) (k - 1)
  =   combina (n - 2) k       + combina (n - 2) (k - 1)
    + combina (n - 2) (k - 1) + combina (n - 2) (k - 2)

See how combina (n - 2) (k - 1) is calculated twice here? This process repeats many times, leading to your algorithm taking exponential time. Ouch. How can you fix that? Well, imagine making a table of all the results of combina. Now you can walk along each row of the table to calculate the next one, avoiding the repeated work. This is really just calculating Pascal's triangle.

combina :: Int -> Int -> Integer
combina n k
  | k > n || k < 0 = 0
  | otherwise = build n !! n !! k

step :: [Integer] -> [Integer]
step xs = 1 : zipWith (+) xs (drop 1 xs) ++ [1]

build :: Int -> [[Integer]]
build n0 = go (n0 + 1) [1]
  where
    go 0 _ = []
    go n row = row : go (n - 1) (step row)

The key idea is in step, which calculates each row from the previous one.
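To see step at work, here is a compact variant that applies it repeatedly with iterate and keeps only the row that is needed. The names pascalRow and combinaPascal are hypothetical and chosen for illustration; step is repeated from above so that the block stands on its own.

```haskell
-- Each inner entry of the next row is the sum of two neighbouring entries
-- of the current row, with a 1 added at either end.
step :: [Integer] -> [Integer]
step xs = 1 : zipWith (+) xs (drop 1 xs) ++ [1]

-- Row n of Pascal's triangle: apply step n times to row 0, which is [1].
-- (pascalRow is a hypothetical name.)
pascalRow :: Int -> [Integer]
pascalRow n = iterate step [1] !! n

-- C(n, k) is entry k of row n; out-of-range k gives 0.
combinaPascal :: Int -> Int -> Integer
combinaPascal n k
  | k < 0 || k > n = 0
  | otherwise      = pascalRow n !! k
```

For example, the rows develop as [1], [1,1], [1,2,1], [1,3,3,1], [1,4,6,4,1], so pascalRow 4 is [1,4,6,4,1] and combinaPascal 4 2 is 6.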