How to write generic Monad law tests?


Given the definitions:

trait Functor[F[_]]:
  extension [A](fa: F[A]) def map[B](f: A => B): F[B]

trait Monad[F[_]] extends Functor[F]:
  def unit[A](a: => A): F[A]

  extension [A](fa: F[A])
    def flatMap[B](f: A => F[B]): F[B]

    def map[B](f: A => B): F[B] =
      flatMap(f.andThen(unit))

object MonadSyntax:
  // 'fA' is the thing on which the method >>= is invoked.
  extension [F[_], A](fA: F[A])(using mA: Monad[F])
    def >>=[B](f: A => F[B]): F[B] = mA.flatMap(fA)(f)

I want to write tests using ScalaTest and ScalaCheck to verify the three Monad laws. Instead of repeating the tests for every Monad instance, I'd like to make the tests generic. My attempt is as follows:

trait MonadLaws[F[_]] { this: AnyFunSpec & ScalaCheckPropertyChecks =>
  // Monad[F].unit(x).flatMap(f) === f(x)
  def leftIdentity[A: Arbitrary, B: Arbitrary](
    using arbAToFB: Arbitrary[A => F[B]], 
    ma: Monad[F]
    ): Unit =
    it("should satisfy the left identity law"):
      forAll { (x: A, f: A => F[B]) =>
        (ma.unit(x) >>= f) shouldBe f(x)
      }

  // m.flatMap(Monad[F].unit) === m
  def rightIdentity[A: Arbitrary](using eqM: Equality[Monad[F]], ma: Monad[F]): Unit =
    it("should satisfy the right identity law"):
        val left = for
          a: A <- ma
        yield summon[Monad[F]].unit(a)

      left shouldBe ma

  // m.flatMap(f).flatMap(g) === m.flatMap { x => f(x).flatMap(g) }
  def associativityLaw[A: Arbitrary, B: Arbitrary, C: Arbitrary](
    using arbAToFB: Arbitrary[A => F[B]],
    arbBToFB: Arbitrary[B => F[C]],
    ma: Monad[F],
    eqM: Equality[Monad[F]]
  ): Unit =
    it("should satisfy the associativity law"):
      forAll { (f: A => F[B], g: B => F[C]) =>
        ma.flatMap(f).flatMap(g) shouldBe m.flatMap(x => f(x).flatMap(g))
      }
}

A particular Monad instance would then run the tests as follows:

class OptionMonadLawsSpec extends AnyFunSpec with ScalaCheckPropertyChecks with MonadLaws[Option]:
  describe("Options form a Monad"):
    import MonadInstances.optionMonad

    leftIdentity[Int, Int]

Imports have been omitted for brevity.
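The `MonadInstances.optionMonad` instance imported above isn't shown either. A minimal sketch of what it might look like, assuming the `Functor` and `Monad` traits from the top of the question (repeated here so the snippet compiles on its own); the body of the instance and the `halve` helper are illustrative guesses, not the asker's actual code:

```scala
// Definitions repeated from the question; map uses an explicit lambda
// instead of f.andThen(unit), with the same behavior.
trait Functor[F[_]]:
  extension [A](fa: F[A]) def map[B](f: A => B): F[B]

trait Monad[F[_]] extends Functor[F]:
  def unit[A](a: => A): F[A]

  extension [A](fa: F[A])
    def flatMap[B](f: A => F[B]): F[B]

    def map[B](f: A => B): F[B] =
      flatMap(a => unit(f(a)))

// Hypothetical: the question imports this instance but never shows it.
object MonadInstances:
  given optionMonad: Monad[Option] with
    def unit[A](a: => A): Option[A] = Some(a)

    extension [A](fa: Option[A])
      def flatMap[B](f: A => Option[B]): Option[B] =
        fa match
          case Some(a) => f(a)
          case None    => None

// Hypothetical sample function, used only to exercise the instance.
def halve(n: Int): Option[Int] =
  if n % 2 == 0 then Some(n / 2) else None
```

With these definitions, `MonadInstances.optionMonad.flatMap(Option(8))(halve)` evaluates to `Some(4)`, and the same call on `Option(3)` gives `None`. The instance is called explicitly here because `Option` already has its own `flatMap` member, which would otherwise take precedence over the extension method.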

The problem is that the second and third tests don't compile, both failing with:

Found:    A => F[A]
Required: F[A²]

where:    A  is a type in method rightIdentity
          A² is a type variable with constraint 
          F  is a type in trait MonadLaws with bounds <: [_] =>> Any

I've racked my brain for more than an hour and can't get the types right. Also, how do I invoke Monad[F].unit where unit is an instance method?

Best answer, by Abhijit Sarkar:

@LuisMiguelMejíaSuárez helped me on the Typelevel Discord channel to get the code compiling. I've yet to test it thoroughly, but the types check out, and he was also kind enough to point out a basic misunderstanding.

trait MonadLaws[F[_]] { this: AnyFunSpec & ScalaCheckPropertyChecks =>
  // Monad[F].unit(x).flatMap(f) === f(x)
  def leftIdentity[A, B](using
      Monad[F],
      Arbitrary[A],
      Arbitrary[A => F[B]],
      Equality[F[B]]
  ): Unit =
    it("should satisfy the left identity law"):
      forAll { (a: A, f: A => F[B]) =>
        val lhs = summon[Monad[F]].unit(a) >>= f
        val rhs = f(a)

        lhs shouldBe rhs
      }

  // m.flatMap(Monad[F].unit) === m
  def rightIdentity[A](using Monad[F], Arbitrary[F[A]], Equality[F[A]]): Unit =
    it("should satisfy the right identity law"):
      forAll { (fa: F[A]) =>
        val lhs = fa >>= summon[Monad[F]].unit
        val rhs = fa

        lhs shouldBe rhs
      }

  // m.flatMap(f).flatMap(g) === m.flatMap { x => f(x).flatMap(g) }
  def associativityLaw[A, B, C](using
      Monad[F],
      Arbitrary[F[A]],
      Arbitrary[A => F[B]],
      Arbitrary[B => F[C]],
      Equality[F[C]]
  ): Unit =
    it("should satisfy the associativity law"):
      forAll { (fa: F[A], f: A => F[B], g: B => F[C]) =>
        val lhs = fa >>= f >>= g
        val rhs = fa >>= (a => f(a) >>= (g))

        lhs shouldBe rhs
      }
}
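
The misunderstanding isn't spelled out above. Comparing the two versions, it appears to be that `ma: Monad[F]` is the type class instance, not a monadic value: the laws have to be stated over generated values `fa: F[A]`, with the instance only supplying `unit` and `flatMap`. As for calling `Monad[F].unit`, one common approach is a summoner on the companion object. The sketch below is dependency-free and rests on assumptions: the `Monad.apply` summoner, the `LawChecks` object, the `optionMonad` body, and `halve` are all hypothetical names, and plain `==` stands in for `Equality`.

```scala
trait Functor[F[_]]:
  extension [A](fa: F[A]) def map[B](f: A => B): F[B]

trait Monad[F[_]] extends Functor[F]:
  def unit[A](a: => A): F[A]

  extension [A](fa: F[A])
    def flatMap[B](f: A => F[B]): F[B]

    def map[B](f: A => B): F[B] =
      flatMap(a => unit(f(a)))

// Hypothetical summoner: lets call sites write Monad[F].unit(x).
object Monad:
  def apply[F[_]](using m: Monad[F]): Monad[F] = m

// Hypothetical instance, standing in for the one the question imports.
object MonadInstances:
  given optionMonad: Monad[Option] with
    def unit[A](a: => A): Option[A] = Some(a)

    extension [A](fa: Option[A])
      def flatMap[B](f: A => Option[B]): Option[B] =
        fa match
          case Some(a) => f(a)
          case None    => None

// Each law as a Boolean check on concrete inputs. Because F is abstract
// here, fa.flatMap resolves to the extension method of the Monad[F] instance.
object LawChecks:
  def leftIdentity[F[_], A, B](a: A, f: A => F[B])(using Monad[F]): Boolean =
    Monad[F].unit(a).flatMap(f) == f(a)

  def rightIdentity[F[_], A](fa: F[A])(using Monad[F]): Boolean =
    fa.flatMap(a => Monad[F].unit(a)) == fa

  def associativity[F[_], A, B, C](fa: F[A], f: A => F[B], g: B => F[C])(
      using Monad[F]
  ): Boolean =
    fa.flatMap(f).flatMap(g) == fa.flatMap(a => f(a).flatMap(g))

// Hypothetical sample function used to exercise the checks.
def halve(n: Int): Option[Int] =
  if n % 2 == 0 then Some(n / 2) else None
```

After `import MonadInstances.given`, a call such as `LawChecks.associativity[Option, Int, Int, Int](Option(8), halve, halve)` checks one law on one input; the ScalaTest version above does the same over generated inputs. One thing worth verifying in the ScalaTest version: as far as I can tell from the matcher documentation, `shouldBe` does not consult an implicit `Equality`, whereas `shouldEqual` does, so the `Equality[F[_]]` context parameters may only take effect with the latter.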