理解 Scala 中的 for 表达式

Scala 中的 for 表达式非常强大,用起来很简单顺手,但是理解起来可能需要一些背景知识。这个 for 从名字上看让我们觉得它很像命令式语言中的那种 for statement,但是其实 Scala 的 for 是非常函数式的,它跟命令式语言中的 for 语句根本不是一个东西,因为它根本不支持 break 和 continue 这些语句。

官方文档提供的例子:

val twentySomethings = for (user <- userBase if (user.age >=20 && user.age < 30))
  yield user.name  // i.e. add this to a list

Scala 的 for 表达式倒是很像 Python 和 Haskell 里面的 list comprehension,其实这个 for 表达式基本上可以等同于 Haskell 的 list comprehension 加上 do annotation 了。然而 Haskell 里面 list comprehension 和 do 也是一个东西,所以基本等同于 for comprehension。

在理解 for 之前,需要先了解 mapflatMap,这些概念在 Scala 和 Haskell 中很相似:

Map

map 大多数人会用在 List 上面,但是更广义的概念上,map 可以用在任何带有 context 的类型上面,不仅限于 List:

// map 通过一个函数,可以把 A 元素的列表转换成列表 B 的
map(f: A => B): List[A] => List[B]

同理,map 不仅可以用在 List 上面,也可以用在类似 Option 和 Future 这样的类型上面。从直观意义上理解,map 同样可以把 Option[A] 转换成 Option[B],把 Future[A] 转换成 Future[B]

在 Hasekll 里面,这种支持 map 操作的类型,叫做 Functor,Scala 里面似乎没有一个名字。

FlatMap

flatMap 理解起来则稍微复杂一些,看一下 List 里面的 flatMap:

def flatMap[U](f: T => List[U]): List[U]

这里面的 f 不再像 map 里面返回一个新的子类型,而是,返回一个带有 context 类型本身: List[U]。也就是说 f 函数返回的新类型是包含在 context 里面的。需要注意的是在这里,返回的结果并不是 List[List[U]],而是仅仅是 List[U] 。通过 flatMap 这个名字中的 flat 我们也知道,最后的结果会是把每个 list 都拼接在一起,也就是所谓的拍平了。

上面提到的是 List 的 flatMap,同样地,Future 类型同样也是支持 flatMap 的,它的的参数 f 类型我们应该也猜到了:

def flatMap[U](f: T => Future[U]): Future[U]

通过 flatMap 我们就可以做很多事情了。这个时候假设我们有函数 fetch1fetch2 ,都是返回 Future[Response],如果我们要先 fetch1 然后根据 1 的 response 来 fetch2 ,然后把 2 的结果打印出来,我们要怎么做呢:

// 写法一
val resp2 = fetch1
.flatMap { resp1 =>
    // do something
    fetch2(resp1.data) // 通过 1 的结果来返回新的 Future
}
.map { resp2 =>
    printf(resp2.data)
    resp2.data
}

然后回到开头,我们可以用 for comprehension 来完成这个事情:

// 写法二
val resp2 = for {
    resp1 <- fetch1()
    resp2 <- fetch2(resp1.data)
    printf(resp2.data)
} yield resp2.data

其实实际上,Scala 会把上述代码翻译成与下面等价的样子:

// 写法三
val resp2 = resp1
.flatMap { resp1 =>
    fetch2(resp2.data)
    .map {
        printf(resp2.data)
    	resp2.data
    }
}

只要你仔细观察,你就会发现,在这个例子里面,这三种写法的结果都是一样的,但是第一种写法有什么缺点呢,缺点就是在第二个 flatMap 里面,resp1 已经不在作用域里面了,这个时候如果你还想调用 resp1 已经不可能了,但是写法三是可以的:

// 写法三
val resp2 = resp1
.flatMap { resp1 =>
    fetch2(resp2.data)
    .map {
        printf(resp1.data) // 在这里仍然可以调用 resp1,因为它还在作用域里面
        printf(resp2.data)
    	resp2.data
    }
}

这样看起来就跟 Haskell 的 Monad 很接近了。

For 表达式的几种翻译方式

其实明白了 map 和 flatMap 之后,for 表达式很容易理解了。但是实现起来还要参考 for 表达式的翻译方式

If

值得注意的是 if 语法:

for(x <- c; if cond) yield {...}

会被翻译成:

c.withFilter(x => cond).map(x => {...})

当然会有 fallback (如果不支持 withFilter):

c.filter(x => cond).map(x => {...})

这时候再看文章开头的例子:

val twentySomethings = for (user <- userBase if (user.age >=20 && user.age < 30))
  yield user.name  // i.e. add this to a list

会被翻译成:

val twentySomethings = userBase
.filter { user =>
    user.age >=20 && user.age < 30
}
.map { _.name }

就很容易理解了

赋值语句

for 表达式中同样支持赋值语句(这里说的是 = 不是 <-),这在处理异步请求的时候非常常见:

val resp2 = for {
    resp1 <- fetch1()
    data = doSomething(resp1)
    resp2 <- fetch2(data)
    printf(resp2.data)
} yield resp2.data

注意,在 for comprehension 中写 = 赋值语句并不需要写 var 或者 val ,至于为什么,只要看它会被翻译成什么样子就能明白:

val resp2 = for {
    resp1 <- fetch1()
    data = doSomething(resp1)
    resp2 <- fetch2(data)
    printf(resp2.data)
} yield resp2.data

var resp2 = fetch1()
.map { resp1 =>
    (resp1, doSomething(resp1))
}
.flatMap { 
    case (resp1, data) =>
    	fetch2(data)
    .map { resp2 =>
        printf(resp2.data)
        resp2.data
    }
}

所以这个时候再写 valvar 已经没什么意义了

CPS 变换

其实 Scala 做的这个变换是一种 CPS(Continuation Passing Style) 变换,把一系列看上去想是赋值 <- 的操作转换成 CPS 的形式,这有什么好处呢,通过这种变换,map 和 flatMap 的参数 f 就可以是一个纯(Pure)的函数,也就是说通过一系列纯函数的组合,可以实现像命令式编程里面的有状态的赋值操作,而这些操作是含有上下文(context)的。

另一个好处是,像下面这种调用:

c1.map(p1 => p1.map(p2 => p2.map(p3 => ...)))

其实是尾递归调用,编译器可以为此做很多优化,提高运行速度。

总结

for comprehension 其实是一个很甜很好用的语法糖,可以帮我们省下很多代码。我们当然可以选择继续用 flatMapmap,但是 for 就能用很简单的语言写出来。当然,简单的背后是有代价的,用 for 的过程中就需要使用者明白表达式最终会被怎样转换成相应的 flatMapmap 代码,这样才能让我们写出简洁的 for comprehension。

这个时候,我真的很佩服 Scala 的设计,让整个语言各个部分的设计都非常容易和优雅地组合,设计得相当地精妙,同时我在其中也看到了一些 Haskell 的味道。