On this page:
15.1 Tail Calls
15.2 What is a Tail Call?
15.3 An Interpreter for Proper Calls
15.4 A Compiler with Proper Tail Calls
7.9

15 Jig: jumping to tail calls

    15.1 Tail Calls

    15.2 What is a Tail Call?

    15.3 An Interpreter for Proper Calls

    15.4 A Compiler with Proper Tail Calls

15.1 Tail Calls

With Iniquity, we’ve finally introduced some computational power via the mechanism of functions and function calls. Together with the notion of inductive data, which we have in the form of pairs, we can write fixed-sized programs that operate over arbitrarily large data.

The problem, however, is that there are a class of programs that should operate with a fixed amount of memory, but instead consume memory in proportion to the size of the data they operate on. This is unfortunate because a design flaw in our compiler now leads to asympototically bad run-times.

We can correct this problem by generating space-efficient code for function calls when those calls are in tail position.

Let’s call this language Jig.

There are no syntactic additions required: we simply will properly handling function calls.

15.2 What is a Tail Call?

A tail call is a function call that occurs in tail position. What is tail position and why is important to consider function calls made in this position?

Tail position captures the notion of “the last subexpression that needs to be computed.” If the whole program is some expression e, the e is in tail position. Computing e is the last thing (it’s the only thing!) the program needs to compute.

Let’s look at some examples to get a sense of the subexpressions in tail position. If e is in tail position and e is of the form:

The significance of tail position is relevant to the compilation of calls. Consider the compilation of a call as described in Iniquity: function definitions and calls: arguments are pushed on the call stack, then the 'call instruction is issued, which pushes the address of the return point on the stack and jumps to the called position. When the function returns, the return point is popped off the stack and jumped back to.

But if the call is in tail position, what else is there to do? Nothing. So after the call, return transfers back to the caller, who then just returns itself.

This leads to unconditional stack space consumption on every function call, even function calls that don’t need to consume space.

Consider this program:

; (Listof Number) -> Number
(define (sum xs) (sum/acc xs 0))
 
; (Listof Number) Number -> Number
(define (sum/acc xs a)
  (if (empty? xs)
      a
      (sum/acc (cdr xs) (+ (car xs) a))))

The sum/acc function should operate as efficiently as a loop that iterates over the elements of a list accumulating their sum. But, as currently compiled, the function will push stack frames for each call.

Matters become worse if we were re-write this program in a seemingly benign way to locally bind a variable:

; (Listof Number) Number -> Number
(define (sum/acc xs a)
  (if (empty? xs)
      a
      (let ((b (+ (car xs) a)))
        (sum/acc (cdr xs) b))))

Now the function pushes a return point and a local binding for b on every recursive call.

But we know that whatever the recursive call produces is the answer to the overall call to sum. There’s no need for a new return point and there’s no need to keep the local binding of b since there’s no way this program can depend on it after the recursive call. Instead of pushing a new, useless, return point, we should make the call with whatever the current return point. This is the idea of proper tail calls.

An axe to grind: the notion of proper tail calls is often referred to with misleading terminology such as tail call optimization or tail recursion. Optimization seems to imply it is a nice, but optional strategy for implementing function calls. Consequently, a large number of mainstream programming languages, most notably Java, do not properly implement tail calls. But a language without proper tail calls is fundamentally broken. It means that functions cannot reliably be designed to match the structure of the data they operate on. It means iteration cannot be expressed with function calls. There’s really no justification for it. It’s just broken. Similarly, it’s not about recursion alone (although it is critical for recursion), it really is about getting function calls, all calls, right. /rant

15.3 An Interpreter for Proper Calls

Before addressing the issue of compiling proper tail calls, let’s first think about the interpreter, starting from the interpreter we wrote for Iniquity:

iniquity/interp.rkt

  #lang racket
  (provide interp interp-env interp-prim1)
  (require "ast.rkt"
           "env.rkt"
           "interp-prims.rkt")
   
  ;; type Answer = Value | 'err
   
  ;; type Value =
  ;; | Integer
  ;; | Boolean
  ;; | Character
  ;; | Eof
  ;; | Void
  ;; | '()
  ;; | (cons Value Value)
  ;; | (box Value)
   
  ;; type REnv = (Listof (List Id Value))
  ;; type Defns = (Listof Defn)
   
  ;; Prog Defns -> Answer
  (define (interp p)
    (match p
      [(Prog ds e)
       (interp-env e '() ds)]))
   
  ;; Expr Env Defns -> Answer
  (define (interp-env e r ds)
    (match e
      [(Int i)  i]
      [(Bool b) b]
      [(Char c) c]
      [(Eof)    eof]
      [(Empty)  '()]
      [(Var x)  (lookup r x)]
      [(Prim0 'void) (void)]
      [(Prim0 'read-byte) (read-byte)]
      [(Prim0 'peek-byte) (peek-byte)]
      [(Prim1 p e)
       (match (interp-env e r ds)
         ['err 'err]
         [v (interp-prim1 p v)])]
      [(Prim2 p e1 e2)
       (match (interp-env e1 r ds)
         ['err 'err]
         [v1 (match (interp-env e2 r ds)
               ['err 'err]
               [v2 (interp-prim2 p v1 v2)])])]
      [(If p e1 e2)
       (match (interp-env p r ds)
         ['err 'err]
         [v
          (if v
              (interp-env e1 r ds)
              (interp-env e2 r ds))])]
      [(Begin e1 e2)
       (match (interp-env e1 r ds)
         ['err 'err]
         [_ (interp-env e2 r ds)])]
      [(Let x e1 e2)
       (match (interp-env e1 r ds)
         ['err 'err]
         [v (interp-env e2 (ext r x v) ds)])]
      [(App f es)
       (match (interp-env* es r ds)
         [(list vs ...)
          (match (defns-lookup ds f)
            [(Defn f xs e)
             ; check arity matches
             (if (= (length xs) (length vs))
                 (interp-env e (zip xs vs) ds)
                 'err)])]
         [_ 'err])]))
   
  ;; (Listof Expr) REnv Defns -> (Listof Value) | 'err
  (define (interp-env* es r ds)
    (match es
      ['() '()]
      [(cons e es)
       (match (interp-env e r ds)
         ['err 'err]
         [v (cons v (interp-env* es r ds))])]))
   
  ;; Defns Symbol -> Defn
  (define (defns-lookup ds f)
    (findf (match-lambda [(Defn g _ _) (eq? f g)])
           ds))
   
  (define (zip xs ys)
    (match* (xs ys)
      [('() '()) '()]
      [((cons x xs) (cons y ys))
       (cons (list x y)
             (zip xs ys))]))
   

What needs to be done to make it implement proper tail calls?

Well... not much. Notice how every Iniquity subexpression that is in tail position is interpreted by a call to interp-env that is itself in tail position in the Racket program!

So long as Racket implements tail calls properly, which is does, then this interpreter implements tail calls properly. The interpreter inherits the property of proper tail calls from the meta-language. This is but one reason to do tail calls correctly. Had we transliterated this program to Java, we’d be in trouble as the interpeter would inherit the lack of tail calls and we would have to re-write the interpreter, but as it is, we’re already done.

15.4 A Compiler with Proper Tail Calls

jig/compile.rkt

  #lang racket
  (provide (all-defined-out))
   
  (require "ast.rkt")
   
  ;; An immediate is anything ending in #b000
  ;; All other tags in mask #b111 are pointers
   
  (define result-shift     3)
  (define result-type-mask (sub1 (arithmetic-shift 1 result-shift)))
  (define type-imm         #b000)
  (define type-box         #b001)
  (define type-pair        #b010)
  (define type-string      #b011)
  (define type-proc        #b100)
   
  (define imm-shift        (+ 2 result-shift))
  (define imm-type-mask    (sub1 (arithmetic-shift 1 imm-shift)))
  (define imm-type-int     (arithmetic-shift #b00 result-shift))
  (define imm-type-bool    (arithmetic-shift #b01 result-shift))
  (define imm-type-char    (arithmetic-shift #b10 result-shift))
  (define imm-type-empty   (arithmetic-shift #b11 result-shift))
  (define imm-val-false    imm-type-bool)
  (define imm-val-true
    (bitwise-ior (arithmetic-shift 1 (add1 imm-shift)) imm-type-bool))
   
  ;; Allocate in 64-bit (8-byte) increments, so pointers
  ;; end in #b000 and we tag with #b001 for boxes, etc.
   
  ;; type CEnv = (Listof (Maybe Variable))
   
  ;; Prog -> Asm
  (define (compile p)
    (match p
      [(prog defs e)
       (let ((ds (compile-defines defs))
             (c0 (compile-entry e)))
         `(,@c0
           ,@ds))]))
   
  ;; Expr -> Asm
  ;; Compile e as the entry point
  (define (compile-entry e)
    `(entry
      ,@(compile-tail-e e '())
      ret
      err
      (push rbp)
      (call error)))
    
  ;; Expr CEnv -> Asm
  ;; Compile an expression in tail position
  (define (compile-tail-e e c)
    (match e
      [(var-e v)               (compile-variable v c)]
      [(? imm? i)              (compile-imm i)]
      [(prim-e (? prim? p) es) (compile-prim p es c)]
      [(if-e p t f)            (compile-tail-if p t f c)]
      [(let-e (list b) body)   (compile-tail-let b body c)]
      [(app-e f es)            (compile-tail-call f es c)]))
   
  ;; Expr CEnv -> Asm
  ;; Compile an expression in non-tail position
  (define (compile-e e c)
    (match e
      [(var-e v)               (compile-variable v c)]
      [(? imm? i)              (compile-imm i)]
      [(prim-e (? prim? p) es) (compile-prim p es c)]
      [(if-e p t f)            (compile-if p t f c)]
      [(let-e (list b) body)   (compile-let b body c)]
      [(app-e f es)            (compile-call f es c)]))
   
  ;; Our current set of primitive operations require no function calls,
  ;; so there's no difference between tail and non-tail call positions
  (define (compile-prim p es c)
    (match (cons p es)
      [`(box ,e0)            (compile-box e0 c)]
      [`(unbox ,e0)          (compile-unbox e0 c)]
      [`(cons ,e0 ,e1)       (compile-cons e0 e1 c)]
      [`(car ,e0)            (compile-car e0 c)]
      [`(cdr ,e0)            (compile-cdr e0 c)]
      [`(add1 ,e0)           (compile-add1 e0 c)]
      [`(sub1 ,e0)           (compile-sub1 e0 c)]
      [`(zero? ,e0)          (compile-zero? e0 c)]
      [`(empty? ,e0)         (compile-empty? e0 c)]
      [`(+ ,e0 ,e1)          (compile-+ e0 e1 c)]
      [_            (error
                      (format "prim applied to wrong number of args: ~a ~a" p es))]))
   
  ;; Variable (Listof Expr) CEnv -> Asm
  ;; Statically know the function we're calling
  (define (compile-call f es c)
    (let ((cs (compile-es es (cons #f c)))
          (stack-size (* 8 (length c))))
      `(,@cs
        (sub rsp ,stack-size)
        (call ,(symbol->label f))
        (add rsp ,stack-size))))
   
  ;; Variable (Listof Expr) CEnv -> Asm
  ;; Compile a call in tail position
  (define (compile-tail-call f es c)
    (let ((cs (compile-es es c)))
      `(,@cs
        ,@(move-args (length es) (- (length c)))
        (jmp ,(symbol->label f)))))
   
  ;; Integer Integer -> Asm
  ;; Move i arguments upward on stack by offset off
  (define (move-args i off)
    (match i
          [0 '()]
          [_ `(,@(move-args (sub1 i) off)
               (mov rbx (offset rsp ,(- off i)))
               (mov (offset rsp ,(- i)) rbx))]))
   
  ;; (Listof Expr) CEnv -> Asm
  (define (compile-es es c)
    (match es
      ['() '()]
      [(cons e es)
       (let ((c0 (compile-e e c))
             (cs (compile-es es (cons #f c))))
         `(,@c0
           (mov (offset rsp ,(- (add1 (length c)))) rax)
           ,@cs))]))
   
  ;; Variable (Listof Variable) Expr -> Asm
  (define (compile-define def)
    (match def
      [(fundef name args body)
        (let ((c0 (compile-e body (reverse args))))
          `(,(symbol->label name)
            ,@c0
            ret))]))
   
  ;; (Listof Variable) (Listof (Listof Variable)) (Listof Expr) -> Asm
  (define (compile-defines defs)
    (append-map compile-define defs))
   
  ;; Any -> Boolean
  (define (imm? x)
    (or (int-e? x)
        (bool-e? x)
        (char-e? x)
        (nil-e? x)))
   
  ;; Imm -> Asm
  (define (compile-imm i)
    `((mov rax ,(imm->bits i))))
   
  ;; Imm -> Integer
  (define (imm->bits i)
    (match i
      [(int-e i)  (arithmetic-shift i imm-shift)]
      [(char-e c) (+ (arithmetic-shift (char->integer c) imm-shift) imm-type-char)]
      [(bool-e b) (if b imm-val-true imm-val-false)]
      [(nil-e)    imm-type-empty]))
   
   
  ;; Variable CEnv -> Asm
  (define (compile-variable x c)
    (let ((i (lookup x c)))
      `((mov rax (offset rsp ,(- (add1 i)))))))
   
  ;; Expr CEnv -> Asm
  (define (compile-box e0 c)
    (let ((c0 (compile-e e0 c)))
      `(,@c0
        (mov (offset rdi 0) rax)
        (mov rax rdi)
        (or rax ,type-box)
        (add rdi 8)))) ; allocate 8 bytes
   
  ;; Expr CEnv -> Asm
  (define (compile-unbox e0 c)
    (let ((c0 (compile-e e0 c)))
      `(,@c0
        ,@assert-box
        (xor rax ,type-box)
        (mov rax (offset rax 0)))))
   
  ;; Expr Expr CEnv -> Asm
  (define (compile-cons e0 e1 c)
    (let ((c0 (compile-e e0 c))
          (c1 (compile-e e1 (cons #f c))))
      `(,@c0
        (mov (offset rsp ,(- (add1 (length c)))) rax)
        ,@c1
        (mov (offset rdi 0) rax)
        (mov rax (offset rsp ,(- (add1 (length c)))))
        (mov (offset rdi 1) rax)
        (mov rax rdi)
        (or rax ,type-pair)
        (add rdi 16))))
   
  ;; Expr CEnv -> Asm
  (define (compile-car e0 c)
    (let ((c0 (compile-e e0 c)))
      `(,@c0
        ,@assert-pair
        (xor rax ,type-pair)
        (mov rax (offset rax 1)))))
   
  ;; Expr CEnv -> Asm
  (define (compile-cdr e0 c)
    (let ((c0 (compile-e e0 c)))
      `(,@c0
        ,@assert-pair
        (xor rax ,type-pair)
        (mov rax (offset rax 0)))))
   
  ;; Expr CEnv -> Asm
  (define (compile-empty? e0 c)
    (let ((c0 (compile-e e0 c))
          (l0 (gensym)))
      `(,@c0
        (and rax ,imm-type-mask)
        (cmp rax ,imm-type-empty)
        (mov rax ,imm-val-false)
        (jne ,l0)
        (mov rax ,imm-val-true)
        ,l0)))
   
  ;; Expr CEnv -> Asm
  (define (compile-add1 e0 c)
    (let ((c0 (compile-e e0 c)))
      `(,@c0
        ,@assert-integer
        (add rax ,(arithmetic-shift 1 imm-shift)))))
   
  ;; Expr CEnv -> Asm
  (define (compile-sub1 e0 c)
    (let ((c0 (compile-e e0 c)))
      `(,@c0
        ,@assert-integer
        (sub rax ,(arithmetic-shift 1 imm-shift)))))
   
  ;; Expr CEnv -> Asm
  (define (compile-zero? e0 c)
    (let ((c0 (compile-e e0 c))
          (l0 (gensym))
          (l1 (gensym)))
      `(,@c0
        ,@assert-integer
        (cmp rax 0)
        (mov rax ,imm-val-false)
        (jne ,l0)
        (mov rax ,imm-val-true)
        ,l0)))
   
  ;; Expr Expr Expr CEnv -> Asm
  (define (compile-if e0 e1 e2 c)
    (let ((c0 (compile-e e0 c))
          (c1 (compile-e e1 c))
          (c2 (compile-e e2 c))
          (l0 (gensym))
          (l1 (gensym)))
      `(,@c0
        (cmp rax ,imm-val-false)
        (je ,l0)
        ,@c1
        (jmp ,l1)
        ,l0
        ,@c2
        ,l1)))
   
  ;; Expr Expr Expr CEnv -> Asm
  (define (compile-tail-if e0 e1 e2 c)
    (let ((c0 (compile-e e0 c))
          (c1 (compile-tail-e e1 c))
          (c2 (compile-tail-e e2 c))
          (l0 (gensym))
          (l1 (gensym)))
      `(,@c0
        (cmp rax ,imm-val-false)
        (je ,l0)
        ,@c1
        (jmp ,l1)
        ,l0
        ,@c2
        ,l1)))
   
  ;; Variable Expr Expr CEnv -> Asm
  (define (compile-tail-let b e1 c)
    (match b
      [(binding v def)
         (let ((c0 (compile-e def c))
               (c1 (compile-tail-e e1 (cons v c))))
           `(,@c0
             (mov (offset rsp ,(- (add1 (length c)))) rax)
             ,@c1))]
      [_ (error "Compile-let can only handle bindings")]))
   
  ;; Variable Expr Expr CEnv -> Asm
  (define (compile-let b e1 c)
    (match b
      [(binding v def)
         (let ((c0 (compile-e def c))
               (c1 (compile-e e1 (cons v c))))
           `(,@c0
             (mov (offset rsp ,(- (add1 (length c)))) rax)
             ,@c1))]
      [_ (error "Compile-let can only handle bindings")]))
   
  ;; Expr Expr CEnv -> Asm
  (define (compile-+ e0 e1 c)
    (let ((c1 (compile-e e1 c))
          (c0 (compile-e e0 (cons #f c))))
      `(,@c1
        ,@assert-integer
        (mov (offset rsp ,(- (add1 (length c)))) rax)
        ,@c0
        ,@assert-integer
        (add rax (offset rsp ,(- (add1 (length c))))))))
   
   
  (define (type-pred->mask p)
    (match p
      [(or 'box? 'cons? 'string? 'procedure?) result-type-mask]
      [_ imm-type-mask]))
   
  (define (type-pred->tag p)
    (match p
      ['box?       type-box]
      ['cons?      type-pair]
      ['string?    type-string]
      ['procedure? type-proc]
      ['integer?   imm-type-int]
      ['empty?     imm-type-empty]
      ['char?      imm-type-char]
      ['boolean?   imm-type-bool]))
   
  ;; Variable CEnv -> Natural
  (define (lookup x cenv)
    (match cenv
      ['() (error "undefined variable:" x)]
      [(cons y cenv)
       (match (eq? x y)
         [#t (length cenv)]
         [#f (lookup x cenv)])]))
   
  (define (assert-type p)
    `((mov rbx rax)
      (and rbx ,(type-pred->mask p))
      (cmp rbx ,(type-pred->tag p))
      (jne err)))
   
  (define assert-integer (assert-type 'integer?))
  (define assert-box     (assert-type 'box?))
  (define assert-pair    (assert-type 'cons?))
  (define assert-string  (assert-type 'string?))
  (define assert-char    (assert-type 'char?))
  (define assert-proc    (assert-type 'procedure?))
   
  ;; Asm
  (define assert-natural
    `(,@assert-integer
      (cmp rax -1)
      (jle err)))
   
  ;; Asm
  (define assert-integer-codepoint
    `((mov rbx rax)
      (and rbx ,imm-type-mask)
      (cmp rbx 0)
      (jne err)
      (cmp rax ,(arithmetic-shift -1 imm-shift))
      (jle err)
      (cmp rax ,(arithmetic-shift #x10FFFF imm-shift))
      (mov rbx rax)
      (sar rbx ,(+ 11 imm-shift))
      (cmp rbx #b11011)
      (je err)))
   
  ;; Symbol -> Label
  ;; Produce a symbol that is a valid Nasm label
  (define (symbol->label s)
    (string->symbol
     (string-append
      "label_"
      (list->string
       (map (λ (c)
              (if (or (char<=? #\a c #\z)
                      (char<=? #\A c #\Z)
                      (char<=? #\0 c #\9)
                      (memq c '(#\_ #\$ #\# #\@ #\~ #\. #\?)))
                  c
                  #\_))
           (string->list (symbol->string s))))
      "_"
      (number->string (eq-hash-code s) 16))))