
A First Look at Scala Macros

Posted by cayhorstmann on January 14, 2013 at 9:32 PM PST

The final version of Scala 2.10 was released on January 4, 2013. Martin Christensen, a visiting scholar in our department, and I have been playing with some of the new features, and I'll be blogging about some of our discoveries in my copious spare time.

Today, I'll show you how to write a simple macro in Scala. You may have seen macros in C, such as #define swap(x, y) { int temp = x; x = y; y = temp; }

C macros are just text substitutions. If you call swap(first, last), the result is { int temp = first; first = last; last = temp; }, which is fine. But if you call swap(first, temp) or swap(a[i++], a[j]), bad things will happen. That's why C++ programmers are told to avoid macros and to use inline functions instead.
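To make the two failure cases concrete, here is a small sketch in Scala (not C) that imitates what a text-substituting preprocessor does with the swap macro. The names TextMacroDemo and expandSwap are made up for this illustration, and the substitution is deliberately naive; a real C preprocessor works on tokens, but the resulting expansions are the same for these calls.

```scala
import java.util.regex.Matcher

object TextMacroDemo {
  // Body of the C macro from the text, with formal parameters x and y.
  val body = "{ int temp = x; x = y; y = temp; }"

  // Naive expansion: replace each whole-word parameter with the argument text.
  def expandSwap(argX: String, argY: String): String = {
    val actuals = Map("x" -> argX, "y" -> argY)
    """\b(x|y)\b""".r.replaceAllIn(body, m => Matcher.quoteReplacement(actuals(m.group(1))))
  }

  def main(args: Array[String]): Unit = {
    println(expandSwap("first", "last"))   // the harmless case
    println(expandSwap("first", "temp"))   // the macro's temp captures the caller's temp
    println(expandSwap("a[i++]", "a[j]"))  // a[i++] appears twice, so i is incremented twice
  }
}
```

In the second expansion the last statement becomes temp = temp, so the caller's variable is never updated. In the third, the argument text a[i++] is pasted in twice.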

In Scheme, macros are much more commonly used. Here is a swap macro in Scheme:

(define-syntax-rule (swap x y)
  (let ([temp x]) (set! x y) (set! y temp)))

If you call (swap first last), then you get (let ([temp first]) (set! first last) (set! last temp)). This looks just like #define in C, but there is an important difference. If you call (swap first temp), then the temp in the macro will automatically be renamed. Macros in Scheme are “hygienic”. In Scheme, it's easy to write macros, even complex ones, because Scheme programs are just lists.

Scala 2.10 has an experimental macro facility. It's not as easy as in Scheme, of course, because Scala programs aren't just lists. You have to know how to manipulate Scala parse trees.

First off, the mechanics. You define a macro as if it were a function, but then you use the macro keyword to link it to an actual function that works on parse trees.

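  import scala.language.experimental.macros // enables the experimental macro keyword
  import scala.reflect.macros.Context       // the Context type in Scala 2.10
  def swap(a: Any, b: Any): Unit = macro swap_impl
  def swap_impl(c: Context)(a: c.Expr[Any], b: c.Expr[Any]): c.Expr[Unit] = ...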

This particular macro is willing to take as arguments expressions of any type, so the corresponding parse trees have type c.Expr[Any]. The “context” c encapsulates the compile-time services that are available in the macro implementation, such as type checking, logging errors, and many others.

Now we need to know how to construct a parse tree. Fortunately, there are some functions that let us display parse trees. For example, run this in the REPL:

import scala.reflect.runtime.universe._
print(showRaw(reify{println("Hello") }.tree))

The result is

Apply(Select(Select(This(newTypeName("scala")), 
newTermName("Predef")), newTermName("println")),
List(Literal(Constant("Hello"))))

If you squint at it, you can see how this is scala.Predef.println("Hello").
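You don't have to squint. The tree classes are case classes with extractors, so you can take a tree apart with pattern matching. Here is a minimal sketch, assuming Scala 2.x with scala-reflect on the classpath; the names InspectTree, describeCall, and helloTree are made up for this example.

```scala
import scala.reflect.runtime.universe._

object InspectTree {
  // For a tree of the form qualifier.method(literal), returns the method name and the literal value.
  def describeCall(tree: Tree): Option[(String, Any)] = tree match {
    case Apply(Select(_, name), List(Literal(Constant(value)))) => Some((name.toString, value))
    case _ => None
  }

  // The same tree that was printed above with showRaw.
  def helloTree: Tree = reify { println("Hello") }.tree

  def main(args: Array[String]): Unit = {
    println(describeCall(helloTree))
  }
}
```

The pattern mirrors the showRaw output: an Apply whose function part is a Select and whose argument list holds a single Literal.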

So, next I tried the following in the REPL:

var a = 3
var b = 4
print(showRaw(reify{var temp = a; a = b; b = temp }.tree))

I was rewarded with this beauty:

Block(List(ValDef(Modifiers(MUTABLE), 
newTermName("temp"), TypeTree(), Select(Select(Select(Select(Select(Select(Ident($line4),
newTermName("$read")), newTermName("$iw")), newTermName("$iw")),
newTermName("$iw")), newTermName("$iw")), newTermName("a"))),
Apply(Select(Select(Select(Select(Select(Select(Ident($line4), newTermName("$read")),
newTermName("$iw")), newTermName("$iw")), newTermName("$iw")),
newTermName("$iw")), newTermName("a_$eq")),
List(Select(Select(Select(Select(Select(Select(Ident($line8), newTermName("$read")),
newTermName("$iw")), newTermName("$iw")), newTermName("$iw")),
newTermName("$iw")), newTermName("b"))))),
Apply(Select(Select(Select(Select(Select(Select(Ident($line8), newTermName("$read")),
newTermName("$iw")), newTermName("$iw")), newTermName("$iw")),
newTermName("$iw")), newTermName("b_$eq")),
List(Ident(newTermName("temp")))))

Ugh. Apparently, when you write

var a = 3

in the REPL, you declare $line4.$read.$iw.$iw.$iw.a or some such thing. Let's try something simpler, defining the variables in a block.

print(showRaw(reify { var a = 3; var b = 4; { var temp = a; a = b; b = temp }}.tree))

That's better:

Block(List(ValDef(Modifiers(MUTABLE), newTermName("a"), TypeTree(),
Literal(Constant(3))), ValDef(Modifiers(MUTABLE), newTermName("b"), TypeTree(),
Literal(Constant(4)))), Block(List(ValDef(Modifiers(MUTABLE), newTermName("temp"),
TypeTree(), Ident(newTermName("a"))), Assign(Ident(newTermName("a")),
Ident(newTermName("b")))), Assign(Ident(newTermName("b")), Ident(newTermName("temp")))))

Now we can get going. We need to construct a

Block(List(ValDef(Modifiers(MUTABLE), newTermName("temp"), TypeTree(), Ident(...)), Assign(Ident(...), Ident(...))),
Assign(Ident(...), Ident(newTermName("temp"))))
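Before moving this into a macro, you can build the same shape by hand with the runtime universe and print it. This is only a sketch under a few assumptions: Scala 2.x with scala-reflect on the classpath, and the names BuildSwapTree and swapTree are invented here. Note that newTermName is the Scala 2.10 spelling; later 2.x releases deprecate it in favor of TermName.

```scala
import scala.reflect.runtime.universe._
import scala.reflect.runtime.universe.Flag._

object BuildSwapTree {
  // Builds the tree for { var temp = x; x = y; y = temp } from two identifier names.
  def swapTree(x: String, y: String): Block = {
    val temp = newTermName("temp")
    Block(
      List(
        ValDef(Modifiers(MUTABLE), temp, TypeTree(), Ident(newTermName(x))),
        Assign(Ident(newTermName(x)), Ident(newTermName(y)))),
      Assign(Ident(newTermName(y)), Ident(temp)))
  }

  def main(args: Array[String]): Unit = {
    val tree = swapTree("a", "b")
    println(showRaw(tree)) // the raw constructor form, as in the listings above
    println(show(tree))    // the same tree rendered as source-like text
  }
}
```

A Block takes a list of statements followed by a final expression, which is why the last assignment sits outside the List.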

You want to know how the macro is called so you know what expressions you receive. The print/showRaw/reify incantation works for that too:

def swap(x: Int, y: Int) {}
print(showRaw(reify{var a = 3; var b = 4; swap(a, b)}.tree))

yields

...Apply(Select(Select(Select(Select(Select(Select(Ident($line12), 
newTermName("$read")), newTermName("$iw")), newTermName("$iw")),
newTermName("$iw")), newTermName("$iw")), newTermName("swap")),
List(Ident(newTermName("a")), Ident(newTermName("b")))))

Again, the mysterious $read.$iw.$iw.$iw.swap. The list of $iw gets longer as you keep working in the REPL. But you can clearly see what you get: two expressions of the form Ident(newTermName("...")).

That's enough information to write the swap_impl method:

def swap_impl(c: Context)(a: c.Expr[Any], b: c.Expr[Any]): c.Expr[Unit] = {
  import c.universe._
  import c.universe.Flag._
  val unitResult = c.Expr[Unit](Literal(Constant(())))
  a.tree match {
    case ia : Ident => b.tree match {
      case ib : Ident => c.Expr[Unit](Block( // Had to take out List
        ValDef(Modifiers(MUTABLE), newTermName("temp"), TypeTree(), ia),
        Assign(ia, ib),
        Assign(ib, Ident(newTermName("temp")))))
      case _ => unitResult
    }
    case _ => unitResult
  }
}

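And it works. Put the macro definition and swap_impl into an object Swap and compile it before the code that calls the macro, since a macro implementation must already be compiled when the compiler expands a call to it: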

object SwapTest extends App {
  import Swap._
  { // Need to define the variables in a block
    var a = 3
    var b = 4
    swap(a, b)
    println(a)
    println(b)
  } 
}

Exercise 1: What happens if you call swap(first, temp)? How can you fix it?

To my surprise, you get perfectly good error messages when you abuse the macro. For example, change the declaration of a to

var a = "Fred"

The error message says “type mismatch”. That makes sense. You can't swap a String and an Int.

Declare

val a = 3

and the error message says “reassignment to val”.

But not all is well yet. Try

var a = Array(3, 4)
swap(a(0), a(1))

Clearly, this can't work. Now the macro doesn't get an Ident(...). Recall that a(0) means a.apply(0). The macro gets an Apply(Select(..., newTermName("apply")), List(...)), as you can find out by using print/showRaw/reify.

Exercise 2: Do this. What happens? And what exactly is passed to the macro?

In this case, we want to call

val temp = a(0)
a(0) = a(1)
a(1) = temp

where the last two expressions are really

a.update(0, a.apply(1))
a.update(1, temp)

In other words, we are passed a tree containing a.apply(...), and we need to make a tree calling a.update(..., ...).
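Here is a rough sketch of that rewrite in isolation, using hand-built trees from the runtime universe so that it can run outside a macro. It assumes Scala 2.x with scala-reflect on the classpath; ApplyToUpdate, toUpdate, elementRead, and variable are hypothetical names chosen for this example.

```scala
import scala.reflect.runtime.universe._

object ApplyToUpdate {
  // Rewrites the tree for obj.apply(index) into the tree for obj.update(index, value).
  def toUpdate(target: Tree, value: Tree): Option[Tree] = target match {
    case Apply(Select(obj, name), List(index)) if name.toString == "apply" =>
      Some(Apply(Select(obj, newTermName("update")), List(index, value)))
    case _ => None // not an element access, so there is nothing to rewrite
  }

  // Hand-built tree for name.apply(index), the shape that a(0) has when it reaches the macro.
  def elementRead(name: String, index: Int): Tree =
    Apply(Select(Ident(newTermName(name)), newTermName("apply")), List(Literal(Constant(index))))

  // Hand-built tree for a plain variable reference.
  def variable(name: String): Tree = Ident(newTermName(name))

  def main(args: Array[String]): Unit = {
    println(toUpdate(elementRead("a", 0), variable("temp")).map(t => show(t)))
    println(toUpdate(variable("a"), variable("temp")))
  }
}
```

The receiver and the index are reused as they are. Only the method name changes, and the new value is appended as a second argument.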

Finally, what if the variables are fields of an object rather than local variables?

object SwapTest extends App {
  import Swap._
  var a = 3
  var b = 4
  swap(a, b)
  println(a)
  println(b)
}

Now, the call to swap receives the trees for the getter methods SwapTest.a and SwapTest.b.

Exercise 3: Verify this. What exactly is passed to the macro?

We want to generate calls to those getter methods for reading the value, and to the setter methods when writing:

val temp = SwapTest.a
SwapTest.a_$eq(SwapTest.b)
SwapTest.b_$eq(temp)

Here, a_$eq and b_$eq are the setter methods that are automatically generated for a var field. (If you defined them in Scala to replace the defaults, you'd call them a_= and b_=.)
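If you want to see the accessor pair for yourself, the following small program lists it with Java reflection. The class Holder and the helper accessorNames are made up for this demonstration; nothing here depends on macros.

```scala
object SetterDemo {
  class Holder { var a = 3 }

  // Names of the public methods of Holder that start with "a": the getter and the setter.
  def accessorNames: List[String] =
    classOf[Holder].getMethods.map(_.getName).filter(_.startsWith("a")).toList.sorted

  def main(args: Array[String]): Unit = {
    val h = new Holder
    h.a_=(10)   // explicit call to the setter; h.a = 10 compiles to the same call
    println(h.a)
    println(accessorNames.mkString(", "))
  }
}
```

The JVM sees the setter under its encoded name a_$eq, which is the name that the macro has to put into the generated Select node.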

To summarize, we'd like the swap macro to deal with three different kinds of arguments:

  • variables, passed as Ident(...)
  • locations in arrays or other collections, passed as Apply(Select(obj, "apply"), List(index))
  • mutable fields, passed as Select(obj, fieldname)

Here is the implementation of the macro that handles all three. The assign helper function deals with the three cases, turning them into an assignment, a call to update, or a call to the setter.

def swap_impl(c: Context)(a: c.Expr[Any], b: c.Expr[Any]): c.Expr[Unit] = {
  import c.universe._
  import c.universe.Flag._

  def assign(l: c.Expr[Any], r: c.Expr[Any]) = {
    l.tree match {
      case il : Ident => Assign(il, r.tree)
      case Apply(Select(obj, sel), List(index)) if sel.toString == "apply" =>
        Apply(Select(obj, newTermName("update")), List(index, r.tree))       
      case Select(obj, sel) => Apply(Select(obj, sel.toString + "_$eq"), List(r.tree))
      case _ => c.abort(l.tree.pos, "Expected variable or variable(index)")
    }
  }

  c.Expr[Unit](Block(
    ValDef(Modifiers(MUTABLE), newTermName("$temp"), TypeTree(), a.tree),
    assign(a, b),
    assign(b, c.Expr[Any](Ident(newTermName("$temp"))))))
}

Note how an error is reported if none of the three cases applies. If you call swap(a + 1, b), then the tree won't match, and the abort method of the Context trait will cause an error report that points to the offending location.

Exercise 4: Call

val a = List(3, 4)
swap(a(0), a(1))

What error report do you get? Why?

Now you have seen a very basic macro, and you have seen that writing macros is no fun in Scala. But, to put it in perspective, it's a lot better than byte code engineering in Java.

But hope is on the way. The fellow who implemented all this, Eugene Burmako, is working on the macro paradise that should make common cases easier. Personally, I'll be happy when I can write

def swap_impl(c: Context)(a: c.Expr[Any], b: c.Expr[Any]) = {
  c.universe.reify { var temp = a; a = b; b = temp }
}