How to write a function defined for a subset of all enumeration values?


Suppose I have an enumeration A:

object A extends Enumeration { type A = Int; val A1 = 1; val A2 = 2; val A3 = 3 }

I also have a function that is defined only for A1 and A2, not for A3.

def foo(a: A.A): Int = a match {
  case A.A1 => 1 // do something 
  case A.A2 => 2 // do something else 
  case A.A3 => throw new UnsupportedOperationException() 
}

Now I would like foo(A.A3) to be a compilation error. In pseudocode, I would define foo like this:

def foo(a: A1 | A2): Int = ???

How would you suggest writing foo so that it cannot be called with A.A3?

Erik van Oosten On BEST ANSWER

The pseudocode is on the right track: the supported values must appear in the type signature of foo, and the unsupported ones must not.

// Scala 3

object A extends Enumeration {
  type A = Int
  val A1 = 1
  val A2 = 2
  val A3 = 3
}

def foo(a: A.A1.type | A.A2.type): Int = 1

@main
def main(): Unit = {
  foo(A.A1) // compiles
  foo(A.A3) // does not compile
}

Compilation error:

Found:    (A.A3 : Int)
Required: (A.A1 : Int) | (A.A2 : Int)
  foo(A.A3)
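
Inside foo, the argument can then be matched against just the supported constants. The following is a minimal sketch, assuming Scala 3; the result values 1 and 2 are taken from the foo in the question, the name demo is only illustrative, and `extends Enumeration` is left out because the sketch does not need it. Depending on compiler version and flags, the match may still produce an exhaustivity warning.

```scala
// Scala 3 sketch; same constants as in the example above.
object A:
  val A1 = 1
  val A2 = 2
  val A3 = 3

// Only the singleton types of A1 and A2 are accepted as arguments.
def foo(a: A.A1.type | A.A2.type): Int = a match
  case A.A1 => 1 // do something
  case A.A2 => 2 // do something else

@main def demo(): Unit =
  println(foo(A.A1))
  println(foo(A.A2))
  // foo(A.A3) // rejected at compile time, as in the error shown above
```

No case for A.A3 is needed, because a call with A.A3 never reaches the body of foo.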

To make this possible in Scala 2 (literal types are needed, so only Scala 2.13 or Typelevel Scala 2.12), we need to give a more refined type to the constants A1 to A3 and add an implicit conversion to the 'union' type A12:

import scala.language.implicitConversions // avoids the feature warning for the implicit defs

object A extends Enumeration {
  type A = Int
  val A1: 1 = 1
  val A2: 2 = 2
  val A3: 3 = 3

  sealed trait A12 { val a: A }
  private case class AnA12(a: A) extends A12
  implicit def a1ToA12(a1: 1): A12 = AnA12(a1)
  implicit def a2ToA12(a2: 2): A12 = AnA12(a2)
}

object Foo {
  import A._

  def foo(a: A12): Int = a.a

  def main(args: Array[String]): Unit = {
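    foo(A1) // compiles: 1 is implicitly converted to A12
    foo(A3) // does not compile: there is no conversion from 3 to A12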
  }
}

Compilation fails with:

type mismatch;
 found   : 3
 required: Foo.A12
    foo(A3)

AnA12 is private so that callers cannot bypass the check by writing foo(AnA12(A3)).