Spłaszczanie zagnieżdżonych list tego samego typu

Powiedzmy, że chcę spłaszczyć zagnieżdżone listy tego samego typu ... Na przykład

    ListA(Element(A), Element(B), ListA(Element(C), Element(D)), ListB(Element(E),Element(F)))

ListA zawiera zagnieżdżoną listę tego samego typu (ListA(Element(C), Element(D))) więc chcę go zastąpić wartościami, które zawiera, więc wynik górnego przykładu powinien wyglądać tak:

ListA(Element(A), Element(B), Element(C), Element(D), ListB(Element(E),Element(F)))

Bieżąca hierarchia klas:

abstract class SpecialList() extends Exp {
    val elements: List[Exp]
}

case class Element(name: String) extends Exp

case class ListA(elements: List[Exp]) extends SpecialList {
        override def toString(): String = "ListA("+elements.mkString(",")+")"
}

case class ListB(elements: List[Exp]) extends SpecialList {
        override def toString(): String = "ListB("+elements.mkString(",")+")"
}

object ListA{def apply(elements: Exp*):ListA = ListA(elements.toList)}
object ListB{def apply(elements: Exp*):ListB = ListB(elements.toList)}

Zrobiłem trzy rozwiązania, które działają, ale myślę, że musi być lepszy sposób na osiągnięcie tego:

Pierwsze rozwiązanie:

def flatten[T <: SpecialList](parentList: T): List[Exp] = {
        val buf = new ListBuffer[Exp]

        for (feature <- parentList.elements) feature match {
            case listA:ListA if parentList.isInstanceOf[ListA] => buf ++= listA.elements
            case listB:ListB if parentList.isInstanceOf[ListB] => buf ++= listB.elements
            case _ => buf += feature
        }
        buf.toList
    }

Drugie rozwiązanie:

def flatten[T <: SpecialList](parentList: T): List[Exp] = {
    val buf = new ListBuffer[Exp]

    parentList match {
        case listA:ListA => for (elem <- listA.elements) elem match {
                                case listOfTypeA:ListA => buf ++= listOfTypeA.elements
                                case _ => buf += elem
                            }

        case listB:ListB => for (elem <- listB.elements) elem match {
                                case listOfTypeB:ListB => buf ++= listOfTypeB.elements
                                case _ => buf += elem
                            }
    }

    buf.toList
}

Trzecie rozwiązanie

def flatten[T <: SpecialList](parentList: T): List[Exp] = parentList.elements flatMap {
    case listA:ListA if parentList.isInstanceOf[ListA] => listA.elements
    case listB:ListB if parentList.isInstanceOf[ListB] => listB.elements
    case other => List(other)
}

Moje pytanie brzmi: czy istnieje jakikolwiek lepszy, bardziej ogólny sposób na osiągnięcie takiej samej funkcjonalności, jak we wszystkich trzech wyższych rozwiązaniach jest powtórzenie kodu?