Scala 3 collection partitioning with subtypes

427 Views Asked by At

In Scala 3, let's say I have a List[Try[String]]. Can I split it into successes and failures, such that each resulting list has the appropriate subtype?

If I do the following:

// Question snippet: partitioning by a runtime predicate keeps the static element type,
// so both `successes` and `failures` are still typed List[Try[String]].
import scala.util.{Try, Success, Failure}
val tries = List(Success("1"), Failure(Exception("2")))
val (successes, failures) = tries.partition(_.isSuccess)

then successes and failures are still of type List[Try[String]]. The same goes if I filter based on the type:

// filter also returns List[Try[String]]; the [String] argument of the type test is
// unchecked at runtime because of erasure.
val successes = tries.filter(_.isInstanceOf[Success[String]])

I could of course cast to Success and Failure respectively, but is there a type-safe way to achieve this?

1

There is 1 best solution below

6
Dmytro Mitin On BEST ANSWER

@Luis Miguel Mejía Suárez:

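Use tries.partitionMap(_.toEither). Since Try#toEither returns an Either[Throwable, String], this yields a (List[Throwable], List[String]): the failures' exceptions on the left and the success values on the right.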

@mitchus:

@LuisMiguelMejíaSuárez OK, the trick here is that Try has a toEither method that splits it into the proper types. What if we have a regular sealed trait?

In Scala 2 I would do something like

import shapeless.{:+:, ::, CNil, Coproduct, Generic, HList, HNil, Inl, Inr, Poly0}
import shapeless.ops.coproduct.ToHList
import shapeless.ops.hlist.{FillWith, Mapped, Tupler}

// Type class that inserts one coproduct value into the matching list of an HList of lists.
// C is a coproduct H1 :+: H2 :+: ... :+: CNil and L is List[H1] :: List[H2] :: ... :: HNil.
trait Loop[C <: Coproduct, L <: HList] {
  def apply(c: C, l: L): L
}
object Loop {
  // Inductive case: a value in head position (Inl) is prepended to the head list;
  // otherwise (Inr) the head list is kept and the tail coproduct/HList is processed recursively.
  implicit def recur[H, CT <: Coproduct, HT <: HList](implicit
    loop: Loop[CT, HT]
  ): Loop[H :+: CT, List[H] :: HT] = {
    case (Inl(h), hs :: ht) => (h :: hs) :: ht
    case (Inr(ct), hs :: ht) => hs :: loop(ct, ht)
  }

  // Base case: CNil has no values, so the accumulator is returned unchanged.
  implicit val base: Loop[CNil, HNil] = (_, l) => l
}

// Poly0 whose case yields an empty List[A] for any A; FillWith uses it to build the
// initial accumulator (an HList of empty lists).
object nilPoly extends Poly0 {
  implicit def cse[A]: Case0[List[A]] = at(Nil)
}

// Splits `as` into one list per direct subtype of the sealed type A and returns them as a
// tuple (List[S1], List[S2], ...), with the subtypes in the order Generic lists them.
// Each Aux instance fixes a type (C, L, L1) that the later implicits depend on.
def partition[A, C <: Coproduct, L <: HList, L1 <: HList](as: List[A])(implicit
  generic: Generic.Aux[A, C],          // A <-> coproduct C of its subtypes
  toHList: ToHList.Aux[C, L],          // C -> HList L of the same subtypes
  mapped: Mapped.Aux[L, List, L1],     // L -> L1, each element type wrapped in List
  loop: Loop[C, L1],                   // inserts one value into L1
  fillWith: FillWith[nilPoly.type, L1], // L1 filled with Nil (the initial accumulator)
  tupler: Tupler[L1]                   // L1 -> tuple
): tupler.Out = {
  // foldRight + prepend keeps the original relative order inside each list.
  val partitionHList: L1 = as.foldRight(fillWith())((a, l1) =>
    loop(generic.to(a), l1)
  )

  tupler(partitionHList)
}

// Example ADT with three subtypes and a sample call; the expected result is shown below it.
sealed trait A
case class B(i: Int) extends A
case class C(i: Int) extends A
case class D(i: Int) extends A

partition(List[A](B(1), B(2), C(1), C(2), D(1), D(2), B(3), C(3))) 
// (List(B(1), B(2), B(3)),List(C(1), C(2), C(3)),List(D(1), D(2))): (List[B], List[C], List[D])

https://scastie.scala-lang.org/DmytroMitin/uQp603sXT7WFYmYntDXmIw/1


I managed to translate this code into Scala 3, although the translation turned out to be wordy (I reimplemented Generic and Coproduct):

import scala.annotation.tailrec
import scala.deriving.Mirror

object App1 {
  // ============= Generic =====================
  // Scala 3 re-implementation of shapeless' Generic: a two-way conversion between T and its
  // generic representation Repr (a tuple for case classes, a Coproduct for sealed hierarchies).
  trait Generic[T] {
    type Repr
    def to(t: T): Repr
    def from(r: Repr): T
  }
  object Generic {
    // Refinement alias that exposes Repr as a type parameter.
    type Aux[T, Repr0] = Generic[T] { type Repr = Repr0 }
    def instance[T, Repr0](f: T => Repr0, g: Repr0 => T): Aux[T, Repr0] =
      new Generic[T] {
        override type Repr = Repr0
        override def to(t: T): Repr0 = f(t)
        override def from(r: Repr0): T = g(r)
      }

    // Extension syntax: `a.toRepr` and `repr.to[A]`.
    object ops {
      extension [A](a: A) {
        def toRepr(using g: Generic[A]): g.Repr = g.to(a)
      }

      extension [Repr](a: Repr) {
        def to[A](using g: Generic.Aux[A, Repr]): A = g.from(a)
      }
    }

    // Product case: Repr is the tuple of field types. `to` rebuilds the fields from
    // productIterator into a tuple (hence the cast); `from` goes through the mirror.
    given [T <: Product](using
      m: Mirror.ProductOf[T]
    ): Aux[T, m.MirroredElemTypes] = instance(
      _.productIterator
       .foldRight[Tuple](EmptyTuple)(_ *: _)
       .asInstanceOf[m.MirroredElemTypes],
      m.fromProduct(_).asInstanceOf[T]
    )

    // Sum case: Repr is the coproduct of the subtypes, bound to C through `ev`.
    // `to` is generated by the matchExpr macro; `from` strips Inr/Inl layers at runtime.
    inline given [T, C <: Coproduct](using
      m: Mirror.SumOf[T],
      ev: Coproduct.ToCoproduct[m.MirroredElemTypes] =:= C
    ): Generic.Aux[T, C] =
      instance(
        matchExpr[T, C](_).asInstanceOf[C],
        Coproduct.unsafeFromCoproduct(_).asInstanceOf[T]
      )

    import scala.quoted.*

    // Macro entry point: expands to a pattern match on `ident` over the subtypes listed in C.
    inline def matchExpr[T, C <: Coproduct](ident: T): Coproduct =
      ${matchExprImpl[T, C]('ident)}

    // Generates a match with one case per subtype of C (typed against that subtype);
    // the case at index i calls Coproduct.unsafeToCoproduct(i, ident).
    def matchExprImpl[T: Type, C <: Coproduct : Type](
      ident: Expr[T]
    )(using Quotes): Expr[Coproduct] = {
      import quotes.reflect.*

      // Flattens H1 +: (H2 +: (... CNil)) into List(H1, H2, ...); CNil is not an
      // applied type, so it ends the recursion.
      def unwrapCoproduct(typeRepr: TypeRepr): List[TypeRepr] = typeRepr match {
        case AppliedType(_, List(typ1, typ2)) => typ1 :: unwrapCoproduct(typ2)
        case _  => Nil
      }

      val typeReprs = unwrapCoproduct(TypeRepr.of[C])

      // Reference to Coproduct.unsafeToCoproduct, used as the body of every generated case.
      val methodIdent =
        Ident(TermRef(TypeRepr.of[Coproduct.type], "unsafeToCoproduct"))

      // Inferred(typeRepr) (instead of TypeIdent(typeRepr.typeSymbol)) keeps the type
      // arguments of parametric subtypes.
      def caseDefs(ident: Term): List[CaseDef] =
        typeReprs.zipWithIndex.map { (typeRepr, i) =>
          CaseDef(
            Typed(ident, Inferred(typeRepr) /*TypeIdent(typeRepr.typeSymbol)*/),
            None,
            Block(
              Nil,
              Apply(
                methodIdent,
                List(Literal(IntConstant(i)), ident)
              )
            )
          )
        }

      def matchTerm(ident: Term): Term = Match(ident, caseDefs(ident))

      matchTerm(ident.asTerm).asExprOf[Coproduct]
    }
  }

  // ============= Coproduct =====================
  // Minimal shapeless-style coproduct: a value of H +: T is either Inl(head: H) or
  // Inr(tail: T); CNil has no values.
  sealed trait Coproduct extends Product with Serializable
  sealed trait +:[+H, +T <: Coproduct] extends Coproduct
  final case class Inl[+H, +T <: Coproduct](head: H) extends (H +: T)
  final case class Inr[+H, +T <: Coproduct](tail: T) extends (H +: T)
  sealed trait CNil extends Coproduct

  object Coproduct {
    // Wraps `value` as alternative number `length` (0-based): `length` Inr layers around
    // an Inl. The result is untyped; callers cast it to the concrete coproduct type.
    def unsafeToCoproduct(length: Int, value: Any): Coproduct =
      (0 until length).foldLeft[Coproduct](Inl(value))((c, _) => Inr(c))

    // Inverse of unsafeToCoproduct: strips Inr layers and returns the Inl payload.
    @tailrec
    def unsafeFromCoproduct(c: Coproduct): Any = c match {
      case Inl(h) => h
      case Inr(c) => unsafeFromCoproduct(c)
      case _: CNil => sys.error("impossible")
    }

    // Maps a tuple of types H1 *: H2 *: ... *: EmptyTuple to H1 +: H2 +: ... +: CNil.
    type ToCoproduct[T <: Tuple] <: Coproduct = T match {
      case EmptyTuple => CNil
      case h *: t => h +: ToCoproduct[t]
    }

//    type ToTuple[C <: Coproduct] <: Tuple = C match {
//      case CNil => EmptyTuple
//      case h +: t => h *: ToTuple[t]
//    }

    // Type-class version of the (commented-out) match type above: Out is the tuple
    // H1 *: H2 *: ... *: EmptyTuple for a coproduct H1 +: H2 +: ... +: CNil.
    trait ToTuple[C <: Coproduct] {
      type Out <: Tuple
    }
    object ToTuple {
      type Aux[C <: Coproduct, Out0 <: Tuple] = ToTuple[C] { type Out = Out0 }
      // Instances carry no runtime data; they only fix the Out type member.
      def instance[C <: Coproduct, Out0 <: Tuple]: Aux[C, Out0] =
        new ToTuple[C] { override type Out = Out0 }

      // Inductive case: prepend H to the tuple computed for the tail T.
      given [H, T <: Coproduct](using 
        toTuple: ToTuple[T]
      ): Aux[H +: T, H *: toTuple.Out] = instance
      // Base case.
      given Aux[CNil, EmptyTuple] = instance
    }
  }
}

// different file
import App1.{+:, CNil, Coproduct, Generic, Inl, Inr}

object App2 {    
  // Type class that prepends one coproduct value to the matching list of the tuple L, where
  // L is List[H1] *: List[H2] *: ... *: EmptyTuple for a coproduct H1 +: H2 +: ... +: CNil.
  trait Loop[C <: Coproduct, L <: Tuple] {
    def apply(c: C, l: L): L
  }
  object Loop {
    // Inl: prepend the value to the head list. Inr: keep the head list and process the tail.
    given [H, CT <: Coproduct, HT <: Tuple](using 
      loop: Loop[CT, HT]
    ): Loop[H +: CT, List[H] *: HT] = {
      case (Inl(h), hs *: ht) => (h :: hs) *: ht
      case (Inr(ct), hs *: ht) => hs *: loop(ct, ht)
    }

    // CNil has no values, so the accumulator is returned unchanged.
    given Loop[CNil, EmptyTuple] = (_, l) => l
  }

  // Type class that builds a tuple of empty lists: Nil *: Nil *: ... *: EmptyTuple.
  trait FillWithNil[L <: Tuple] {
    def apply(): L
  }
  object FillWithNil {
    given [H, T <: Tuple](using 
      fillWithNil: FillWithNil[T]
    ): FillWithNil[List[H] *: T] = () => Nil *: fillWithNil()
    given FillWithNil[EmptyTuple] = () => EmptyTuple
  }

  // Splits `as` into one list per subtype of the sealed type A. The result type L1 is
  // List[S1] *: List[S2] *: ... for the subtypes S1, S2, ... in Generic's order; `ev`
  // fixes L1 by wrapping each element type of toTuple.Out in List.
  // foldRight + prepend keeps the original relative order inside each list.
  def partition[A, /*L <: Tuple,*/ L1 <: Tuple](as: List[A])(using
    generic: Generic.Aux[A, _ <: Coproduct],
    toTuple: Coproduct.ToTuple[generic.Repr],
    //ev0: Coproduct.ToTuple[generic.Repr] =:= L, // compile-time NPE
    ev: Tuple.Map[toTuple.Out/*L*/, List] =:= L1,
    loop: Loop[generic.Repr, L1],
    fillWith: FillWithNil[L1]
  ): L1 = as.foldRight(fillWith())((a, l1) => loop(generic.to(a), l1))

  // Example ADT.
  sealed trait A
  case class B(i: Int) extends A
  case class C(i: Int) extends A
  case class D(i: Int) extends A

  def main(args: Array[String]): Unit = {
    println(partition(List[A](B(1), B(2), C(1), C(2), D(1), D(2), B(3), C(3))))
  // (List(B(1), B(2), B(3)),List(C(1), C(2), C(3)),List(D(1), D(2)))
  }
}

Scala 3.0.2


In the macro (which generates the pattern match), Inferred(typeRepr) should be used instead of TypeIdent(typeRepr.typeSymbol); otherwise it doesn't work for parametric case classes. Actually, the macro can be removed altogether if we use the mirror's ordinal. A simplified version is:

import scala.deriving.Mirror
import scala.util.NotGiven

// Two-way conversion between T and its generic representation Repr
// (a tuple for case classes, a Coproduct for sealed hierarchies).
trait Generic[T] {
  type Repr
  def to(t: T): Repr
  def from(r: Repr): T
}

object Generic {
  // Refinement alias that exposes Repr as a type parameter.
  type Aux[T, Repr0] = Generic[T] {type Repr = Repr0}

  def instance[T, Repr0](f: T => Repr0, g: Repr0 => T): Aux[T, Repr0] =
    new Generic[T] {
      override type Repr = Repr0
      override def to(t: T): Repr0 = f(t)
      override def from(r: Repr0): T = g(r)
    }

  // Extension syntax: `a.toRepr` and `repr.to[A]`.
  object ops {
    extension[A] (a: A) {
      def toRepr(using g: Generic[A]): g.Repr = g.to(a)
    }

    extension[Repr] (a: Repr) {
      def to[A](using g: Generic.Aux[A, Repr]): A = g.from(a)
    }
  }

  // Product case: Repr is the tuple of field types. `to` builds that tuple from T's
  // fields through the tuple's own mirror m1; `from` rebuilds T through T's mirror m.
  // The commented-out NotGiven evidences would rule out T being a Tuple or a Coproduct.
  given [T <: Product](using
    // ev: NotGiven[T <:< Tuple],
    // ev1: NotGiven[T <:< Coproduct],
    m: Mirror.ProductOf[T],
    m1: Mirror.ProductOf[m.MirroredElemTypes]
  ): Aux[T, m.MirroredElemTypes] = instance(
    m1.fromProduct(_),
    m.fromProduct(_)
  )

  // Sum case: Repr is the coproduct of the subtypes, bound to C through ev2.
  // `to` uses the mirror's ordinal as the number of Inr layers (no macro needed);
  // `from` strips the Inr/Inl layers at runtime.
  given[T, C <: Coproduct](using
    // ev: NotGiven[T <:< Tuple],
    // ev1: NotGiven[T <:< Coproduct],
    m: Mirror.SumOf[T],
    ev2: Coproduct.ToCoproduct[m.MirroredElemTypes] =:= C
  ): Generic.Aux[T, C/*Coproduct.ToCoproduct[m.MirroredElemTypes]*/] = {
    instance(
      t => Coproduct.unsafeToCoproduct(m.ordinal(t), t).asInstanceOf[C],
      Coproduct.unsafeFromCoproduct(_).asInstanceOf[T]
    )
  }
}
// Minimal shapeless-style coproduct: a value of H +: T is either Inl(head: H) or
// Inr(tail: T); CNil has no values and terminates the chain.
sealed trait Coproduct extends Product with Serializable
sealed trait +:[+H, +T <: Coproduct] extends Coproduct
final case class Inl[+H, +T <: Coproduct](head: H) extends (H +: T)
final case class Inr[+H, +T <: Coproduct](tail: T) extends (H +: T)
sealed trait CNil extends Coproduct

object Coproduct {
  // Wraps `value` as alternative number `length` (0-based) of an untyped coproduct:
  // `length` nested Inr layers around Inl(value). A non-positive `length` gives Inl(value).
  // Callers cast the result to the precise coproduct type.
  def unsafeToCoproduct(length: Int, value: Any): Coproduct = {
    @scala.annotation.tailrec
    def wrap(layersLeft: Int, acc: Coproduct): Coproduct =
      if (layersLeft <= 0) acc else wrap(layersLeft - 1, Inr(acc))
    wrap(length, Inl(value))
  }

  // Inverse of unsafeToCoproduct: peels off every Inr layer and returns the payload of
  // the innermost Inl. Reaching CNil (which has no values) is treated as impossible.
  @scala.annotation.tailrec
  def unsafeFromCoproduct(c: Coproduct): Any = c match {
    case Inr(rest)  => unsafeFromCoproduct(rest)
    case Inl(value) => value
    case _: CNil    => sys.error("impossible")
  }

  // Type-level map: H1 *: ... *: Hn *: EmptyTuple  ~>  H1 +: ... +: Hn +: CNil.
  type ToCoproduct[T <: Tuple] <: Coproduct = T match {
    case EmptyTuple => CNil
    case hd *: tl   => hd +: ToCoproduct[tl]
  }

  // Type-level inverse: H1 +: ... +: Hn +: CNil  ~>  H1 *: ... *: Hn *: EmptyTuple.
  type ToTuple[C <: Coproduct] <: Tuple = C match {
    case CNil     => EmptyTuple
    case hd +: tl => hd *: ToTuple[tl]
  }
}

Replacing the type classes with compile-time (inline) methods and match types:

import scala.compiletime.erasedValue

// Compile-time replacement for the Loop type class: the inline matches on erasedValue[C]
// and erasedValue[L] pick a branch during inlining; the inner runtime match then prepends
// the Inl payload to its list or recurses on Inr. The branch results are cast back to L
// (presumably because the compiler cannot prove they have type L).
inline def loop[C <: Coproduct, L <: Tuple](c: C, l: L): L = (inline erasedValue[C] match {
  case _: CNil => inline erasedValue[L] match {
    case _: EmptyTuple => EmptyTuple
  }
  case _: (h +: ct) => inline erasedValue[L] match {
    case _: (List[`h`] *: ht) => (c, l) match {
      case (Inl(h_v: `h`), (hs_v: List[`h`]) *: (ht_v: `ht`)) => 
        (h_v :: hs_v) *: ht_v
      case (Inr(ct_v: `ct`), (hs_v: List[`h`]) *: (ht_v: `ht`)) => 
        hs_v *: loop[ct, ht](ct_v, ht_v)
    }
  }
}).asInstanceOf[L]

// Compile-time replacement for FillWithNil: for L = List[H1] *: ... *: EmptyTuple it
// expands to Nil *: ... *: EmptyTuple, which is then cast back to L.
inline def fillWithNil[L <: Tuple]: L = (inline erasedValue[L] match {
  case _: EmptyTuple => EmptyTuple
  case _: (List[h] *: t) => Nil *: fillWithNil[t]
}).asInstanceOf[L]

// For a coproduct H1 +: ... +: Hn +: CNil, the tuple List[H1] *: ... *: List[Hn] *: EmptyTuple.
type TupleList[C <: Coproduct] = Tuple.Map[Coproduct.ToTuple[C], List]

// Splits `as` into one list per subtype of the sealed type A, typed as
// TupleList[generic.Repr]. foldRight + prepend keeps each list in input order.
inline def partition[A](as: List[A])(using
  generic: Generic.Aux[A, _ <: Coproduct]
): TupleList[generic.Repr] =
  as.foldRight(fillWithNil[TupleList[generic.Repr]])((a, l1) => loop(generic.to(a), l1))
// Example ADT with three subtypes.
sealed trait A
case class B(i: Int) extends A
case class C(i: Int) extends A
case class D(i: Int) extends A

// Demo entry point: partitions a mixed List[A] by subtype and prints the resulting tuple.
@main def test: Unit = {
  val mixed = List[A](B(1), B(2), C(1), C(2), D(1), D(2), B(3), C(3))
  val bySubtype = partition(mixed)
  println(bySubtype)
  // (List(B(1), B(2), B(3)),List(C(1), C(2), C(3)),List(D(1), D(2)))
}

Tested in 3.2.0 https://scastie.scala-lang.org/DmytroMitin/940QaiqDQQ2QegCyxTbEIQ/1

How to access parameter list of case class in a dotty macro


Alternative implementation of loop

//Loop[C, L] = L
// Match-type description of one loop step; for well-formed inputs it reduces back to L.
type Loop[C <: Coproduct, L <: Tuple] <: Tuple = C match {
  case CNil    => CNilLoop[L]
  case h +: ct => CConsLoop[h, ct, L]
}
// match types seem not to support nested type matching
// (hence the helper aliases below, which perform the inner match on L).
type CNilLoop[L <: Tuple] <: Tuple = L match {
  case EmptyTuple => EmptyTuple
}
type CConsLoop[H, CT <: Coproduct, L <: Tuple] <: Tuple = L match {
  case List[H] *: ht => List[H] *: Loop[CT, ht]
}
// Value-level counterpart of the Loop match type: prepends the Inl payload to its list or
// recurses on Inr, with the result typed as Loop[C, L].
// NOTE(review): the type patterns binding h/ct/ht are unchecked at runtime (erasure);
// correctness relies on l having the shape List[H1] *: ... that matches C.
/*inline*/ def loop0[C <: Coproduct, L <: Tuple](c: C, l: L): Loop[C, L] = /*inline*/ c match {
  // CNil has no values, so this branch is not reached with real data.
  case _: CNil => /*inline*/ l match {
    case _: EmptyTuple => EmptyTuple
  }
  case c: (h +: ct) => /*inline*/ l match {
    case l: (List[`h`] *: ht) => (c, l) match {
      // The tail is unchanged (type ht), but the declared result requires Loop[ct, ht];
      // the cast bridges the two (Loop[ct, ht] = ht per the comment on Loop).
      case (Inl(h_v/*: `h`*/), (hs_v/*: List[`h`]*/) *: (ht_v/*: `ht`*/)) =>
        (h_v :: hs_v) *: ht_v.asInstanceOf[Loop[ct, ht]]
      case (Inr(ct_v/*: `ct`*/), (hs_v/*: List[`h`]*/) *: (ht_v/*: `ht`*/)) => 
        hs_v *: loop0[ct, ht](ct_v, ht_v)
    }
  }
}
// Same contract as the type-class Loop: runs one step with loop0 and views the
// match-type result Loop[C, L] as L again.
/*inline*/ def loop[C <: Coproduct, L <: Tuple](c: C, l: L): L = {
  val stepped: Loop[C, L] = loop0(c, l)
  stepped.asInstanceOf[L]
}

Another implementation for Scala 2: Split list of algebraic date type to lists of branches?