How can I combine two scalaz streams with a predicate selector?

I would like to combine two scalaz streams with a predicate which selects the next element from either stream. For instance, I would like this test to pass:

val a = Process(1, 2, 5, 8)
val b = Process(3, 4, 5, 7)

choose(a, b)(_ < _).toList shouldEqual List(1, 2, 3, 4, 5, 5, 7, 8)

As you can see, we can't do something clever like zipping the two streams and ordering each pair, because one of the processes may need to be selected several times in a row.
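To make that concrete, here is a small plain-Scala sketch that uses Lists as stand-ins for the two processes. It is an illustration only, not scalaz-stream code, and `ZipCounterexample` is a hypothetical name. Zipping pairwise and ordering each pair puts 3 before 2, because the second element of `b` is needed before the second element of `a`:

```scala
// Lists stand in for the two Process values from the test above.
object ZipCounterexample {
  val a = List(1, 2, 5, 8)
  val b = List(3, 4, 5, 7)

  // Zip the inputs pairwise and emit each pair in order.
  val zipped: List[Int] = a.zip(b).flatMap { case (x, y) =>
    if (x < y) List(x, y) else List(y, x)
  }

  def main(args: Array[String]): Unit =
    println(zipped) // List(1, 3, 2, 4, 5, 5, 7, 8), not the expected merge
}
```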

I took a stab at a solution that I thought would work. It compiled! But damn it if it doesn't do anything. The JVM just hangs :(

import scalaz.stream.Process._
import scalaz.stream._

object StreamStuff {
  def choose[F[_], I](a:Process[F, I], b:Process[F, I])(p: (I, I) => Boolean): Process[F, I] =
    (a.awaitOption zip b.awaitOption).flatMap {
      case (Some(ai), Some(bi)) =>
        if(p(ai, bi)) emit(ai) ++ choose(a, emit(bi) ++ b)(p)
        else emit(bi) ++ choose(emit(ai) ++ a, b)(p)
      case (None, Some(bi)) => emit(bi) ++ b
      case (Some(ai), None) => emit(ai) ++ a
      case _ => halt
    }
}

Note that the above was my second attempt. In my first attempt I tried to create a Tee but I couldn't figure out how to un-consume the loser element. I felt that I needed something recursive like I have here.

I am using streams version 0.7.3a.

Any tips (including incremental hints because I'd like to simply learn how to figure these things out on my own) are greatly appreciated!!

Answer by Travis Brown

I'll give a couple of hints and an implementation below, so you might want to cover the screen if you want to work out a solution yourself.

Disclaimer: this is just the first approach that came to mind, and my familiarity with the scalaz-stream API is a little rusty, so there may be nicer ways to implement this operation, or this one might be wrong in some horrible way.

Hint 1

Instead of trying to "unconsume" the losing elements, you can pass them along in the next recursive call.

Hint 2

You can avoid having to accumulate more than one losing element by indicating which side lost last.

Hint 3

I often find it easier to sketch out an implementation using ordinary collections first when I'm working with Scalaz streams. Here's the helper method we'll need for lists:

/**
 * @param p if true, the first of the pair wins
 */
def mergeListsWithHeld[A](p: (A, A) => Boolean)(held: Either[A, A])(
  ls: List[A],
  rs: List[A]
): List[A] = held match {
  // Right is the current winner.
  case Left(l) => rs match {
    // ...but it's empty.
    case Nil => l :: ls
    // ...and it's still winning.
    case r :: rt if p(r, l) => r :: mergeListsWithHeld(p)(held)(ls, rt)
    // ...upset!
    case r :: rt => l :: mergeListsWithHeld(p)(Right(r))(ls, rt)
  }
  // Left is the current winner.
  case Right(r) => ls match {
    case Nil => r :: rs
    case l :: lt if p(l, r) => l :: mergeListsWithHeld(p)(held)(lt, rs)
    case l :: lt => r :: mergeListsWithHeld(p)(Left(l))(lt, rs)
  }
}

That assumes we've already got a losing element in hand, but now we can write the method we actually want to use:

def mergeListsWith[A](p: (A, A) => Boolean)(ls: List[A], rs: List[A]): List[A] =
  ls match {
    case Nil => rs
    case l :: lt => rs match {
      case Nil => ls
      case r :: rt if p(l, r) => l :: mergeListsWithHeld(p)(Right(r))(lt, rt)
      case r :: rt            => r :: mergeListsWithHeld(p)(Left(l))(lt, rt)
    }
  }

And then:

scala> org.scalacheck.Prop.forAll { (ls: List[Int], rs: List[Int]) =>
     |   mergeListsWith[Int](_ < _)(ls.sorted, rs.sorted) == (ls ++ rs).sorted
     | }.check
+ OK, passed 100 tests.

Okay, looks fine. There are nicer ways we could write this for lists, but this implementation matches the shape of what we'll need to do for Process.
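For comparison, here is one simpler list version, a sketch rather than the answerer's code (`SimpleListMerge` is a hypothetical name). It drops the explicit held element and always calls `p(leftHead, rightHead)`, the argument order the question's `choose` uses, whereas `mergeListsWithHeld` passes the newly read element first, so the two may break ties differently for asymmetric predicates. It is tail-recursive, accumulating the result in reverse:

```scala
import scala.annotation.tailrec

object SimpleListMerge {
  /** Emits the head of `ls` whenever p(lHead, rHead) holds, else the head of `rs`. */
  def mergeWith[A](p: (A, A) => Boolean)(ls: List[A], rs: List[A]): List[A] = {
    @tailrec
    def go(ls: List[A], rs: List[A], acc: List[A]): List[A] = (ls, rs) match {
      case (Nil, _)                     => acc.reverse ::: rs // left exhausted
      case (_, Nil)                     => acc.reverse ::: ls // right exhausted
      case (l :: lt, r :: _) if p(l, r) => go(lt, rs, l :: acc) // left wins, keep r
      case (_, r :: rt)                 => go(ls, rt, r :: acc) // right wins, keep l
    }
    go(ls, rs, Nil)
  }
}
```

On the question's example, `SimpleListMerge.mergeWith[Int](_ < _)(List(1, 2, 5, 8), List(3, 4, 5, 7))` gives `List(1, 2, 3, 4, 5, 5, 7, 8)`, as `mergeListsWith` does. The held-element shape is still the one that carries over to `Process`, since a Tee can't pattern match on "the rest of the stream".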

Implementation

And here's more or less the equivalent with scalaz-stream:

import scalaz.{ -\/, \/, \/- }
import scalaz.stream.Process.{ awaitL, awaitR, emit }
import scalaz.stream.{ Process, Tee, tee }

def mergeWithHeld[A](p: (A, A) => Boolean)(held: A \/ A): Tee[A, A, A] =
  held.fold(_ => awaitR[A], _ => awaitL[A]).awaitOption.flatMap {
    case None =>
      emit(held.merge) ++ held.fold(_ => tee.passL, _ => tee.passR)
    case Some(next) if p(next, held.merge) =>
      emit(next) ++ mergeWithHeld(p)(held)
    case Some(next) =>
      emit(held.merge) ++ mergeWithHeld(p)(
        held.fold(_ => \/-(next), _ => -\/(next))
      )
  }

def mergeWith[A](p: (A, A) => Boolean): Tee[A, A, A] =
  awaitL[A].awaitOption.flatMap {
    case None => tee.passR
    case Some(l) => awaitR[A].awaitOption.flatMap {
      case None =>               emit(l) ++ tee.passL
      case Some(r) if p(l, r) => emit(l) ++ mergeWithHeld(p)(\/-(r))
      case Some(r)            => emit(r) ++ mergeWithHeld(p)(-\/(l))
    }
  }

And let's check it again:

scala> org.scalacheck.Prop.forAll { (ls: List[Int], rs: List[Int]) =>
     |   Process.emitAll(ls.sorted).tee(Process.emitAll(rs.sorted))(
     |     mergeWith(_ < _)
     |   ).toList == (ls ++ rs).sorted
     | }.check
+ OK, passed 100 tests.

I wouldn't put this into production without some more testing, but it looks like it works.

Answer by Captain O.

You have to implement a custom tee, as Travis Brown suggested. Here is my implementation of the tee:

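import scalaz.{ -\/, \/, \/- }
import scalaz.concurrent.Task
import scalaz.stream.{ Process, Tee }
import scalaz.stream.Process.{ awaitL, awaitR, emit, halt }
import scalaz.stream.tee.{ passL, passR }
import scalaz.syntax.std.option._

/*
  A tee which compares the current elements from left and right and
  emits the left element if the predicate returns true, otherwise the
  right one. The losing element is kept for the next comparison.
 */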
def predicateTee[A](predicate: (A, A) => Boolean): Tee[A, A, A] = {

  def go(stack: Option[A \/ A]): Tee[A, A, A] = {
    def stackEither(l: A, r: A) =
      if (predicate(l, r)) emit(l) ++ go(\/-(r).some) else emit(r) ++ go(-\/(l).some)

    stack match {
      case None =>
        awaitL[A].awaitOption.flatMap { lo =>
          awaitR[A].awaitOption.flatMap { ro =>
            (lo, ro) match {
              case (Some(l), Some(r)) => stackEither(l, r)
              case (Some(l), None) => emit(l) ++ passL
              case (None, Some(r)) => emit(r) ++ passR
              case _ => halt
            }
          }
        }
      case Some(-\/(l)) => awaitR[A].awaitOption.flatMap {
        case Some(r) => stackEither(l, r)
        case None => emit(l) ++ passL
      }
      case Some(\/-(r)) => awaitL[A].awaitOption.flatMap {
        case Some(l) => stackEither(l, r)
        case None => emit(r) ++ passR
      }
    }
  }

  go(None)
}

val p1: Process[Task, Int] = Process(1, 2, 4, 5, 9, 10, 11)
val p2: Process[Task, Int] = Process(0, 3, 7, 8, 6)

p1.tee(p2)(predicateTee(_ < _)).runLog.run
//res0: IndexedSeq[Int] = Vector(0, 1, 2, 3, 4, 5, 7, 8, 6, 9, 10, 11)
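The same hold-the-loser state machine can be sketched without scalaz-stream, over two plain `Iterator`s. `IteratorPredicateMerge.merge` is a hypothetical helper, not part of any library. Its `held` field plays the role of `stack` in `go`: `Left(l)` means a left element is waiting, `Right(r)` a right one. Pulling from an iterator consumes that element, much as `awaitL`/`awaitR` consume from the tee's inputs; by contrast, the recursive calls in the question's `choose` receive the original `a` and `b` again, which appears to restart them from their first element every time.

```scala
object IteratorPredicateMerge {
  /** Emits the left element when predicate(l, r) holds, else the right one,
    * keeping the loser in `held` for the next comparison. */
  def merge[A](left: Iterator[A], right: Iterator[A])(predicate: (A, A) => Boolean): Iterator[A] =
    new Iterator[A] {
      private var held: Option[Either[A, A]] = None

      def hasNext: Boolean = held.isDefined || left.hasNext || right.hasNext

      def next(): A = {
        // Reuse the held element instead of reading its side again.
        val lo: Option[A] = held match {
          case Some(Left(l)) => Some(l)
          case _             => if (left.hasNext) Some(left.next()) else None
        }
        val ro: Option[A] = held match {
          case Some(Right(r)) => Some(r)
          case _              => if (right.hasNext) Some(right.next()) else None
        }
        held = None
        (lo, ro) match {
          case (Some(l), Some(r)) =>
            if (predicate(l, r)) { held = Some(Right(r)); l }
            else { held = Some(Left(l)); r }
          case (Some(l), None) => l // right exhausted: pass left through
          case (None, Some(r)) => r // left exhausted: pass right through
          case (None, None)    => throw new NoSuchElementException("next on empty iterator")
        }
      }
    }
}
```

With the same inputs, `IteratorPredicateMerge.merge(Iterator(1, 2, 4, 5, 9, 10, 11), Iterator(0, 3, 7, 8, 6))(_ < _).toVector` yields `Vector(0, 1, 2, 3, 4, 5, 7, 8, 6, 9, 10, 11)`, matching the tee's output. Because `p2` isn't sorted, neither result is sorted.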