15 Jig: jumping to tail calls

    15.1 Tail Calls

    15.2 What is a Tail Call?

    15.3 An Interpreter for Proper Calls

    15.4 A Compiler with Proper Tail Calls

15.1 Tail Calls

With Iniquity, we’ve finally introduced some computational power via the mechanism of functions and function calls. Together with the notion of inductive data, which we have in the form of pairs, we can write fixed-sized programs that operate over arbitrarily large data.

The problem, however, is that there is a class of programs that should operate with a fixed amount of memory, but instead consume memory in proportion to the size of the data they operate on. This is unfortunate because a design flaw in our compiler now leads to asymptotically bad memory use.

We can correct this problem by generating space-efficient code for function calls when those calls are in tail position.

Let’s call this language Jig.

There are no syntactic additions required: we simply need to handle function calls properly.

15.2 What is a Tail Call?

A tail call is a function call that occurs in tail position. What is tail position, and why is it important to consider function calls made in this position?

Tail position captures the notion of “the last subexpression that needs to be computed.” If the whole program is some expression e, then e is in tail position. Computing e is the last thing (it’s the only thing!) the program needs to compute.

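Let’s look at some examples to get a sense of the subexpressions in tail position. If e is in tail position and e is of the form:

    (if e0 e1 e2): then e1 and e2 are in tail position, but e0 is not.

    (let ((x e0)) e1): then e1 is in tail position, but e0 is not.

    (begin e0 e1): then e1 is in tail position, but e0 is not.

    a primitive application such as (car e0) or (+ e0 e1): then none of its subexpressions are in tail position, since the primitive still has work to do after they are computed.

    a function call (f e0 ...): then none of the arguments are in tail position, for the same reason.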
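The tail-position rules for if, let, and begin (the same ones compile-tail-e in jig/compile.rkt follows) can be sketched as a small Racket function that collects the calls appearing in tail position. This is a hypothetical helper, not part of the course code: tail-calls and prim? are illustrative names, and the function works on plain s-expressions rather than the AST structs from ast.rkt.

```racket
#lang racket
;; Hypothetical helper (not course code): list the function calls that occur
;; in tail position within e, assuming e itself is in tail position.
;; Expressions are plain s-expressions here, not the structs from ast.rkt.

;; Primitive operations are not function calls: they still have work to do
;; after their operands are computed.
(define prims
  '(add1 sub1 zero? char? char->integer integer->char eof-object?
    write-byte box unbox car cdr empty? + - eq? cons void read-byte peek-byte))

(define (prim? p) (and (memq p prims) #t))

(define (tail-calls e)
  (match e
    ;; only the branches of an if are in tail position, not the test
    [(list 'if _ e1 e2)                 (append (tail-calls e1) (tail-calls e2))]
    ;; only the body of a let is in tail position, not the bound expression
    [(list 'let (list (list _ _)) body) (tail-calls body)]
    ;; only the last expression of a begin is in tail position
    [(list 'begin _ e2)                 (tail-calls e2)]
    ;; primitive application: no operand is in tail position
    [(cons (? prim?) _)                 '()]
    ;; a call in tail position is a tail call; its arguments are not
    [(cons (? symbol?) _)               (list e)]
    ;; variables and literals contain no calls
    [_                                  '()]))
```

For example, on the body of the let version of sum/acc, (tail-calls '(if (empty? xs) a (let ((b (+ (car xs) a))) (sum/acc (cdr xs) b)))) returns '((sum/acc (cdr xs) b)), while (tail-calls '(+ 1 (f x))) returns '() because + still has work to do after f returns.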

The significance of tail position shows up in the compilation of calls. Consider the compilation of a call as described in the Iniquity chapter (Iniquity: function definitions and calls): arguments are pushed on the call stack, then the Call instruction is issued, which pushes the address of the return point on the stack and jumps to the called function. When the function returns, the return point is popped off the stack and jumped back to.

But if the call is in tail position, what else is there to do? Nothing. So after the call, control returns to the caller, which then immediately returns itself.

This leads to unconditional stack space consumption on every function call, even function calls that don’t need to consume space.

Consider this program:

; (Listof Number) -> Number
(define (sum xs) (sum/acc xs 0))
 
; (Listof Number) Number -> Number
(define (sum/acc xs a)
  (if (empty? xs)
      a
      (sum/acc (cdr xs) (+ (car xs) a))))

The sum/acc function should operate as efficiently as a loop that iterates over the elements of a list accumulating their sum. But, as currently compiled, every recursive call pushes a new stack frame (its arguments and a return address), so the stack grows with the length of the list.
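To make “as efficiently as a loop” concrete, here is sum/acc next to the loop it ought to behave like, written with Racket’s for/fold. The name sum/loop is illustrative only. Racket already implements tail calls properly, so both versions run in constant stack space there; the goal of this chapter is for our compiler to give sum/acc that same behavior.

```racket
#lang racket
;; sum/acc exactly as in the text above
(define (sum/acc xs a)
  (if (empty? xs)
      a
      (sum/acc (cdr xs) (+ (car xs) a))))

;; Hypothetical loop version: a single accumulator a is updated once per
;; element, and nothing is saved per iteration.
(define (sum/loop xs)
  (for/fold ([a 0]) ([x (in-list xs)])
    (+ x a)))
```

Both (sum/acc (range 1 101) 0) and (sum/loop (range 1 101)) produce 5050: the recursive call in sum/acc plays exactly the role of updating a and going around the loop again.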

Matters become worse if we rewrite this program in a seemingly benign way to locally bind a variable:

; (Listof Number) Number -> Number
(define (sum/acc xs a)
  (if (empty? xs)
      a
      (let ((b (+ (car xs) a)))
        (sum/acc (cdr xs) b))))

Now the function pushes a return point and a local binding for b on every recursive call.

But we know that whatever the recursive call produces is the answer to the overall call to sum. There’s no need for a new return point and there’s no need to keep the local binding of b since there’s no way this program can depend on it after the recursive call. Instead of pushing a new, useless, return point, we should make the call with whatever the current return point is. This is the idea of proper tail calls.

An axe to grind: the notion of proper tail calls is often referred to with misleading terminology such as tail call optimization or tail recursion. Optimization seems to imply it is a nice, but optional strategy for implementing function calls. Consequently, a large number of mainstream programming languages, most notably Java, do not properly implement tail calls. But a language without proper tail calls is fundamentally broken. It means that functions cannot reliably be designed to match the structure of the data they operate on. It means iteration cannot be expressed with function calls. There’s really no justification for it. It’s just broken. Similarly, it’s not about recursion alone (although it is critical for recursion), it really is about getting function calls, all calls, right. /rant

15.3 An Interpreter for Proper Calls

Before addressing the issue of compiling proper tail calls, let’s first think about the interpreter, starting from the interpreter we wrote for Iniquity:

iniquity/interp.rkt

  #lang racket
  (provide interp interp-env interp-prim1)
  (require "ast.rkt"
           "env.rkt"
           "interp-prims.rkt")
   
  ;; type Answer = Value | 'err
   
  ;; type Value =
  ;; | Integer
  ;; | Boolean
  ;; | Character
  ;; | Eof
  ;; | Void
  ;; | '()
  ;; | (cons Value Value)
  ;; | (box Value)
   
  ;; type REnv = (Listof (List Id Value))
  ;; type Defns = (Listof Defn)
   
  ;; Prog -> Answer
  (define (interp p)
    (match p
      [(Prog ds e)
       (interp-env e '() ds)]))
   
  ;; Expr REnv Defns -> Answer
  (define (interp-env e r ds)
    (match e
      [(Int i)  i]
      [(Bool b) b]
      [(Char c) c]
      [(Eof)    eof]
      [(Empty)  '()]
      [(Var x)  (lookup r x)]
      [(Prim0 'void) (void)]
      [(Prim0 'read-byte) (read-byte)]
      [(Prim0 'peek-byte) (peek-byte)]
      [(Prim1 p e)
       (match (interp-env e r ds)
         ['err 'err]
         [v (interp-prim1 p v)])]
      [(Prim2 p e1 e2)
       (match (interp-env e1 r ds)
         ['err 'err]
         [v1 (match (interp-env e2 r ds)
               ['err 'err]
               [v2 (interp-prim2 p v1 v2)])])]
      [(If p e1 e2)
       (match (interp-env p r ds)
         ['err 'err]
         [v
          (if v
              (interp-env e1 r ds)
              (interp-env e2 r ds))])]
      [(Begin e1 e2)
       (match (interp-env e1 r ds)
         ['err 'err]
         [_ (interp-env e2 r ds)])]
      [(Let x e1 e2)
       (match (interp-env e1 r ds)
         ['err 'err]
         [v (interp-env e2 (ext r x v) ds)])]
      [(App f es)
       (match (interp-env* es r ds)
         [(list vs ...)
          (match (defns-lookup ds f)
            [(Defn f xs e)
             ; check arity matches
             (if (= (length xs) (length vs))
                 (interp-env e (zip xs vs) ds)
                 'err)])]
         [_ 'err])]))
   
  ;; (Listof Expr) REnv Defns -> (Listof Value) | 'err
  (define (interp-env* es r ds)
    (match es
      ['() '()]
      [(cons e es)
       (match (interp-env e r ds)
         ['err 'err]
         [v (cons v (interp-env* es r ds))])]))
   
  ;; Defns Symbol -> Defn
  (define (defns-lookup ds f)
    (findf (match-lambda [(Defn g _ _) (eq? f g)])
           ds))
   
  (define (zip xs ys)
    (match* (xs ys)
      [('() '()) '()]
      [((cons x xs) (cons y ys))
       (cons (list x y)
             (zip xs ys))]))
   

What needs to be done to make it implement proper tail calls?

Well... not much. Notice how every Iniquity subexpression that is in tail position is interpreted by a call to interp-env that is itself in tail position in the Racket program!

So long as Racket implements tail calls properly, which it does, this interpreter implements tail calls properly. The interpreter inherits the property of proper tail calls from the meta-language. This is but one reason to do tail calls correctly. Had we transliterated this program to Java, we’d be in trouble: the interpreter would inherit Java’s lack of proper tail calls and we would have to rewrite it. As it is, we’re already done.

15.4 A Compiler with Proper Tail Calls

The compiler requires a bit more work. Because of how the Call instruction is implemented in the hardware itself (it pushes the return address before jumping), we always use a little bit of stack space each time we execute it. Therefore, in order to implement tail calls correctly, we need to avoid the Call instruction for calls in tail position!

How do we perform function calls without the Call instruction? We’re going to have to do a little bit of extra work in the compiler. First, let’s remind ourselves of how a ‘normal’ function call works (we’ll just look at the case where we don’t have to adjust for alignment):

(define (compile-app f es c)
 
         ; Generate the code for each argument
         ; and push each on the stack
    (seq (compile-es es c)
 
         ; Generate the instruction for calling the function itself
         (Call (symbol->label f))
 
         ; pop all of the arguments off of the stack
         (Add rsp (* 8 (length es)))))

The first insight regards what the stack will look like once we are inside the function we are calling. Upon entry to the function’s code, rsp will point to the return address that the last Call instruction pushed onto the stack, with the arguments to the function at positive offsets from rsp. As long as we ensure that this is the case, we can enter a function’s code with a plain jump instead of Call.
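To see exactly where the arguments sit on entry, here is a sketch that reuses parity and lookup from jig/compile.rkt (lookup’s error message is shortened) to compute each parameter’s byte offset from rsp. The helpers entry-env and param-offsets are hypothetical names; entry-env mirrors the environment that compile-define builds for a function body.

```racket
#lang racket
;; parity and lookup, as in jig/compile.rkt
(define (parity c)
  (if (even? (length c))
      (append c (list #f))
      c))

(define (lookup x cenv)
  (match cenv
    ['() (error "undefined variable:" x)]
    [(cons y rest)
     (match (eq? x y)
       [#t 0]
       [#f (+ 8 (lookup x rest))])]))

;; Hypothetical helper: the environment compile-define builds for parameters xs.
;; Slot 0 (#f) is the return address pushed by Call; the parameters follow in
;; reverse order, since the last argument is pushed last.
(define (entry-env xs)
  (parity (cons #f (reverse xs))))

;; Hypothetical helper: the byte offset from rsp of each parameter on entry.
(define (param-offsets xs)
  (let ((env (entry-env xs)))
    (map (λ (x) (cons x (lookup x env))) xs)))
```

For sum/acc, (param-offsets '(xs a)) gives '((xs . 16) (a . 8)): a sits just above the return address at rsp+0, and xs above that. Any way of entering sum/acc, with or without Call, has to recreate this picture before jumping.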

The second insight is what we mentioned above, when describing tail calls themselves: if we’re performing a call in tail position, then there is nothing else to do when it returns. So instead of having the callee return here, only for us to return again, the callee can return directly to our caller. That means we can overwrite the current environment on the stack, since we won’t need it (there’s nothing else to do, after all). In jargon: we can reuse the stack frame. The only thing we have to be careful about is whether the current frame is ‘big enough’ to hold all of the arguments for the new call; since we are going to reuse it, we need to make sure there’s enough space.
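The frame-reuse idea can be checked with a toy model of the stack. In this sketch (not course code) the stack is a Racket vector, rsp is an index into it, pushing decrements rsp as on x86, and offsets count words rather than bytes. The helper move-args! is a hypothetical, directly executed version of the copying step that the compiler emits as assembly. The scenario is the recursive call in the let version of sum/acc: two new arguments, and one local binding (b) in the current frame.

```racket
#lang racket
;; Toy model (hypothetical): simulate a tail call that reuses the caller's frame.
(define (simulate-tail-call)
  (define stack (make-vector 16 #f))
  (define rsp 16)
  (define (push! v)
    (set! rsp (sub1 rsp))
    (vector-set! stack rsp v))
  (define (ref off) (vector-ref stack (+ rsp off)))          ; word at rsp+off
  (define (set-at! off v) (vector-set! stack (+ rsp off) v))
  ;; Copy the top i words up by cnt + 1 words (the +1 skips the return address).
  (define (move-args! i cnt)
    (unless (zero? i)
      (set-at! (+ i cnt) (ref (sub1 i)))
      (move-args! (sub1 i) cnt)))

  ;; Some caller evaluates (sum/acc xs a): push xs, push a, then Call pushes
  ;; the return address.
  (push! 'old-xs)
  (push! 'old-a)
  (push! 'return-address)
  (define entry-rsp rsp)
  ;; The body binds b with let, so the current frame holds one local.
  (push! 'b)
  ;; The arguments for (sum/acc (cdr xs) b) are pushed.
  (push! 'new-xs)
  (push! 'new-a)
  (define cnt 2)    ; number of new arguments
  (define frame 1)  ; local slots between them and the return address
  (move-args! cnt (+ cnt frame))
  (set! rsp (+ rsp cnt frame))   ; drop the new arguments and the old local
  ;; Just before jumping: is rsp back where it was on entry, and what is on top?
  (list (= rsp entry-rsp) (ref 0) (ref 1) (ref 2)))
```

(simulate-tail-call) returns '(#t return-address new-a new-xs): rsp is back at its entry position, the original return address is on top, and the new arguments sit exactly where the old a and xs were, so jumping to sum/acc looks to it like an ordinary call that pushed nothing extra.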

For now assume we’ve performed that check and that there is enough space. Let’s go through the process bit by bit:

; Variable (Listof Expr) CEnv -> Asm
; Compile a call in tail position
(define (compile-tail-call f es c)
  (let ((cnt (length es)))
 
            ; Generate the code for the arguments to the function,
            ; pushing them on the stack, this is no different
            ; than a normal call
       (seq (compile-es es c)
 
 
            ; Now we _move_ the arguments from where they are on the
            ; stack to where the _previous_ values in the environment
            ; were. The function move-args takes the number of values we
            ; have to move, and the number of stack slots (the new
            ; arguments plus the current frame's local bindings) that
            ; sit between them and the return address.
            (move-args cnt (+ cnt (in-frame c)))
 
            ; Once we've moved the arguments, we no longer need them at the
            ; top of the stack. This is a big part of the benefit for
            ; tail-calls
            (Add rsp (* 8 (+ cnt (in-frame c))))
 
            ; Now that rsp points to the _previous_ return address,
            ; and the arguments are at a positive offset of rsp,
            ; we no longer need the call instruction (in fact, it would
            ; be incorrect to use it!), instead we jump to the function
            ; directly.
            (Jmp (symbol->label f)))))

move-args is defined below:

; Integer Integer -> Asm
; Move the top i words of the stack upward by cnt + 1 words
; (the extra word skips over the return address)
(define (move-args i cnt)
  (match i
    [0 (seq)]
    [_ (seq
         ; mov first arg to temp reg
         (Mov r9 (Offset rsp (* 8 (sub1 i))))
         ; mov value to correct place on the old frame
         (Mov (Offset rsp (* 8 (+ i cnt))) r9)
         ; Now do the next one
         (move-args (sub1 i) cnt))]))
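To make the offsets concrete, here is a sketch that produces the same instruction sequence as move-args followed by the last two steps of compile-tail-call, but emitted as quoted lists instead of a86 structs (so it runs without a86) and using the function name directly instead of (symbol->label f). The name tail-call-epilogue is hypothetical. The example is the recursive call in the let version of sum/acc: two arguments, with one local binding (b) in the current frame.

```racket
#lang racket
;; Hypothetical sketch: same sequence as move-args, emitted as s-expressions.
;; Integer Integer -> (Listof S-Expr)
(define (move-args i cnt)
  (if (zero? i)
      '()
      (append (list `(Mov r9 (Offset rsp ,(* 8 (sub1 i))))    ; load an argument
                    `(Mov (Offset rsp ,(* 8 (+ i cnt))) r9))  ; store into old frame
              (move-args (sub1 i) cnt))))

;; Symbol Integer Integer -> (Listof S-Expr)
;; cnt:   number of arguments already pushed by compile-es
;; frame: the value of (in-frame c), the local slots between those
;;        arguments and the return address
(define (tail-call-epilogue f cnt frame)
  (append (move-args cnt (+ cnt frame))
          (list `(Add rsp ,(* 8 (+ cnt frame)))  ; drop new args and old locals
                `(Jmp ,f))))                       ; enter f without pushing
```

(tail-call-epilogue 'sum/acc 2 1) produces '((Mov r9 (Offset rsp 8)) (Mov (Offset rsp 40) r9) (Mov r9 (Offset rsp 0)) (Mov (Offset rsp 32) r9) (Add rsp 24) (Jmp sum/acc)): the two new arguments land at offsets 40 and 32, which become rsp+16 and rsp+8 once rsp is raised by 24 bytes to point at the original return address.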

Looking at the entire compiler shows how we keep track of which expressions are in tail position and whether there is enough space to reuse the stack frame.

jig/compile.rkt

  #lang racket
  (provide (all-defined-out))
  (require "ast.rkt" "types.rkt" a86/ast)
   
  ;; Registers used
  (define rax 'rax) ; return
  (define rbx 'rbx) ; heap
  (define rdx 'rdx) ; return, 2
  (define r8  'r8)  ; scratch in +, -
  (define r9  'r9)  ; scratch in assert-type and tail-calls
  (define rsp 'rsp) ; stack
  (define rdi 'rdi) ; arg
   
  ;; type CEnv = [Listof Variable]
   
  ;; Prog -> Asm
  (define (compile p)
    (match p
      [(Prog ds e)  
       (prog (Extern 'peek_byte)
             (Extern 'read_byte)
             (Extern 'write_byte)
             (Extern 'raise_error)
             (Label 'entry)
             (Mov rbx rdi) ; recv heap pointer
             (compile-e e '(#f)) ; NOT A TAIL CALL! We can't re-use the frame!!!
             (Mov rdx rbx) ; return heap pointer in second return register           
             (Ret)
             (compile-defines ds))]))
   
  ;; [Listof Defn] -> Asm
  (define (compile-defines ds)
    (seq
      (match ds
        ['() (seq)]
        [(cons d ds)
         (seq (compile-define d)
              (compile-defines ds))])))
    
  ;; Defn -> Asm
  (define (compile-define d)
    (match d
      [(Defn f xs e)
                          ; leave space for RIP
       (let ((env (parity (cons #f (reverse xs)))))
            (seq (Label (symbol->label f))
                 ; we need the #args on the frame, not the length of the entire
                 ; env (which may have padding)
                 (compile-tail-e e env (length xs))
                 (Ret)))]))
   
  (define (parity c)
    (if (even? (length c))
        (append c (list #f))
        c))
   
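  ;; Expr CEnv Integer -> Asm
  ;; Compile e in tail position; s is the number of argument slots in the
  ;; current frame, which bounds how many arguments a tail call can reuse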
  (define (compile-tail-e e c s)
    (seq
      (match e
        [(If e1 e2 e3) (compile-tail-if e1 e2 e3 c s)]
        [(Let x e1 e2) (compile-tail-let x e1 e2 c s)]
        [(App f es)    (if (<= (length es) s)
                           (compile-tail-call f es c)
                           (compile-app f es c))]
        [(Begin e1 e2) (compile-tail-begin e1 e2 c s)]
        [_             (compile-e e c)])))
   
  ;; Expr CEnv -> Asm
  (define (compile-e e c)
    (seq
         (match e
           [(? imm? i)      (compile-value (get-imm i))]
           [(Var x)         (compile-variable x c)]
           [(App f es)      (compile-app f es c)]    
           [(Prim0 p)       (compile-prim0 p c)]
           [(Prim1 p e)     (compile-prim1 p e c)]
           [(Prim2 p e1 e2) (compile-prim2 p e1 e2 c)]
           [(If e1 e2 e3)   (compile-if e1 e2 e3 c)]
           [(Begin e1 e2)   (compile-begin e1 e2 c)]
           [(Let x e1 e2)   (compile-let x e1 e2 c)])))
   
  ;; Value -> Asm
  (define (compile-value v)
    (seq (Mov rax (imm->bits v))))
   
  ;; Id CEnv -> Asm
  (define (compile-variable x c)
    (let ((i (lookup x c)))       
      (seq (Mov rax (Offset rsp i)))))
   
  ;; Op0 CEnv -> Asm
  (define (compile-prim0 p c)
    (match p
      ['void      (seq (Mov rax val-void))]
      ['read-byte (seq (pad-stack c)
                       (Call 'read_byte)
                       (unpad-stack c))]
      ['peek-byte (seq (pad-stack c)
                       (Call 'peek_byte)
                       (unpad-stack c))]))
   
  ;; Op1 Expr CEnv -> Asm
  (define (compile-prim1 p e c)
    (seq (compile-e e c)
         (match p
           ['add1
            (seq (assert-integer rax)
                 (Add rax (imm->bits 1)))]
           ['sub1
            (seq (assert-integer rax)
                 (Sub rax (imm->bits 1)))]         
           ['zero?
            (let ((l1 (gensym)))
              (seq (assert-integer rax)
                   (Cmp rax 0)
                   (Mov rax val-true)
                   (Je l1)
                   (Mov rax val-false)
                   (Label l1)))]
           ['char?
            (let ((l1 (gensym)))
              (seq (And rax mask-char)
                   (Xor rax type-char)
                   (Cmp rax 0)
                   (Mov rax val-true)
                   (Je l1)
                   (Mov rax val-false)
                   (Label l1)))]
           ['char->integer
            (seq (assert-char rax)
                 (Sar rax char-shift)
                 (Sal rax int-shift))]
           ['integer->char
            (seq assert-codepoint
                 (Sar rax int-shift)
                 (Sal rax char-shift)
                 (Xor rax type-char))]
           ['eof-object? (eq-imm val-eof)]
           ['write-byte
            (seq assert-byte
                 (pad-stack c)
                 (Mov rdi rax)
                 (Call 'write_byte)
                 (unpad-stack c)
                 (Mov rax val-void))]
           ['box
            (seq (Mov (Offset rbx 0) rax)
                 (Mov rax rbx)
                 (Or rax type-box)
                 (Add rbx 8))]
           ['unbox
            (seq (assert-box rax)
                 (Xor rax type-box)
                 (Mov rax (Offset rax 0)))]
           ['car
            (seq (assert-cons rax)
                 (Xor rax type-cons)
                 (Mov rax (Offset rax 8)))]
           ['cdr
            (seq (assert-cons rax)
                 (Xor rax type-cons)
                 (Mov rax (Offset rax 0)))]
           ['empty? (eq-imm val-empty)])))
   
  ;; Op2 Expr Expr CEnv -> Asm
  (define (compile-prim2 p e1 e2 c)
    (seq (compile-e e1 c)
         (Push rax)
         (compile-e e2 (cons #f c))
         (match p
           ['+
            (seq (Pop r8)
                 (assert-integer r8)
                 (assert-integer rax)
                 (Add rax r8))]
           ['-
            (seq (Pop r8)
                 (assert-integer r8)
                 (assert-integer rax)
                 (Sub r8 rax)
                 (Mov rax r8))]
           ['eq?
            (let ((l (gensym)))
              (seq (Pop r8)        ; pop e1's value before comparing
                   (Cmp rax r8)    ; Mov does not change flags, so Je below sees this
                   (Mov rax val-true)
                   (Je l)
                   (Mov rax val-false)
                   (Label l)))]
           ['cons
            (seq (Mov (Offset rbx 0) rax)
                 (Pop rax)
                 (Mov (Offset rbx 8) rax)
                 (Mov rax rbx)
                 (Or rax type-cons)
                 (Add rbx 16))])))
   
  ;; Id [Listof Expr] CEnv -> Asm
  ;; Here's why this code is so gross: you have to align the stack for the call
  ;; but you have to do it *before* evaluating the arguments es, because you need
  ;; es's values to be just above 'rsp when the call is made.  But if you push
  ;; a frame in order to align the call, you've got to compile es in a static
  ;; environment that accounts for that frame, hence:
  (define (compile-app f es c)
    (if (even? (+ (length es) (length c))) 
        (seq (compile-es es c)
             (Call (symbol->label f))
             (Add rsp (* 8 (length es))))            ; pop args
        (seq (Sub rsp 8)                             ; adjust stack
             (compile-es es (cons #f c))
             (Call (symbol->label f))
             (Add rsp (* 8 (add1 (length es)))))))   ; pop args and pad
   
   
  ;; Variable (Listof Expr) CEnv -> Asm
  ;; Compile a call in tail position
  (define (compile-tail-call f es c)
    (let ((cnt (length es)))
         (seq (compile-es es c)
              (move-args cnt (+ cnt (in-frame c)))
              (Add rsp (* 8 (+ cnt (in-frame c))))
              (Jmp (symbol->label f)))))
   
  ;; Integer Integer -> Asm
  ;; Move the top i words of the stack upward by cnt + 1 words
  ;; (the extra word skips over the return address)
  (define (move-args i cnt)
    (match i
      [0 (seq)]
      [_ (seq
           ; mov first arg to temp reg
           (Mov r9 (Offset rsp (* 8 (sub1 i))))
           ; mov value to correct place on the old frame
           (Mov (Offset rsp (* 8 (+ i cnt))) r9)
           ; Now do the next one
           (move-args (sub1 i) cnt))]))
   
  ;; [Listof Expr] CEnv -> Asm
  (define (compile-es es c)
    (match es
      ['() '()]
      [(cons e es)
       (seq (compile-e e c)
            (Push rax)
            (compile-es es (cons #f c)))]))
   
  ;; Imm -> Asm
  (define (eq-imm imm)
    (let ((l1 (gensym)))
      (seq (Cmp rax imm)
           (Mov rax val-true)
           (Je l1)
           (Mov rax val-false)
           (Label l1))))
   
  ;; Expr Expr Expr CEnv -> Asm
  (define (compile-if e1 e2 e3 c)
    (let ((l1 (gensym 'if))
          (l2 (gensym 'if)))
      (seq (compile-e e1 c)
           (Cmp rax val-false)
           (Je l1)
           (compile-e e2 c)
           (Jmp l2)
           (Label l1)
           (compile-e e3 c)
           (Label l2))))
   
  ;; Expr Expr Expr CEnv Integer -> Asm
  (define (compile-tail-if e1 e2 e3 c s)
    (let ((l1 (gensym 'if))
          (l2 (gensym 'if)))
      (seq (compile-e e1 c)
           (Cmp rax val-false)
           (Je l1)
           (compile-tail-e e2 c s)
           (Jmp l2)
           (Label l1)
           (compile-tail-e e3 c s)
           (Label l2))))
   
  ;; Expr Expr CEnv -> Asm
  (define (compile-begin e1 e2 c)
    (seq (compile-e e1 c)
         (compile-e e2 c)))
   
  ;; Expr Expr CEnv Integer -> Asm
  (define (compile-tail-begin e1 e2 c s)
    (seq (compile-e e1 c)
         (compile-tail-e e2 c s)))
   
  ;; Id Expr Expr CEnv -> Asm
  (define (compile-let x e1 e2 c)
    (seq (compile-e e1 c)
         (Push rax)
         (compile-e e2 (cons x c))
         (Add rsp 8)))
   
  ;; Id Expr Expr CEnv Integer -> Asm
  (define (compile-tail-let x e1 e2 c s)
    (seq (compile-e e1 c)
         (Push rax)
         (compile-tail-e e2 (cons x c) s)
         (Add rsp 8)))
   
  ;; CEnv Integer -> Asm
  ;; Pad the stack to be aligned for a call with stack arguments
  (define (pad-stack-call c i)
    (match (even? (+ (length c) i))
      [#f (seq (Sub rsp 8) (% "padding stack"))]
      [#t (seq)]))
   
  ;; CEnv -> Asm
  ;; Pad the stack to be aligned for a call
  (define (pad-stack c)
    (pad-stack-call c 0))
   
  ;; CEnv Integer -> Asm
  ;; Undo the stack alignment after a call
  (define (unpad-stack-call c i)
    (match (even? (+ (length c) i))
      [#f (seq (Add rsp 8) (% "unpadding"))]
      [#t (seq)]))
   
  ;; CEnv -> Asm
  ;; Undo the stack alignment after a call
  (define (unpad-stack c)
    (unpad-stack-call c 0))
   
  ;; Id CEnv -> Integer
  (define (lookup x cenv)
    (match cenv
      ['() (error "undefined variable:" x " Env: " cenv)]
      [(cons y rest)
       (match (eq? x y)
         [#t 0]
         [#f (+ 8 (lookup x rest))])]))
   
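  ;; CEnv -> Integer
  ;; Count the bindings in the current frame: the entries before the first #f
  ;; (in a function body, that #f is the slot reserved for the return address)
  (define (in-frame cenv)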
    (match cenv
      ['() 0]
      [(cons #f rest) 0]
      [(cons y rest)  (+ 1 (in-frame rest))]))
   
  (define (assert-type mask type)
    (λ (arg)
      (seq (Mov r9 arg)
           (And r9 mask)
           (Cmp r9 type)
           (Jne 'raise_error))))
   
  (define (type-pred mask type)
    (let ((l (gensym)))
      (seq (And rax mask)
           (Cmp rax type)
           (Mov rax (imm->bits #t))
           (Je l)
           (Mov rax (imm->bits #f))
           (Label l))))
           
  (define assert-integer
    (assert-type mask-int type-int))
  (define assert-char
    (assert-type mask-char type-char))
  (define assert-box
    (assert-type ptr-mask type-box))
  (define assert-cons
    (assert-type ptr-mask type-cons))
   
  (define assert-codepoint
    (let ((ok (gensym)))
      (seq (assert-integer rax)
           (Cmp rax (imm->bits 0))
           (Jl 'raise_error)
           (Cmp rax (imm->bits 1114111))
           (Jg 'raise_error)
           (Cmp rax (imm->bits 55295))
           (Jl ok)
           (Cmp rax (imm->bits 57344))
           (Jg ok)
           (Jmp 'raise_error)
           (Label ok))))
         
  (define assert-byte
    (seq (assert-integer rax)
         (Cmp rax (imm->bits 0))
         (Jl 'raise_error)
         (Cmp rax (imm->bits 255))
         (Jg 'raise_error)))
         
  ;; Symbol -> Label
  ;; Produce a symbol that is a valid Nasm label
  (define (symbol->label s)
    (string->symbol
     (string-append
      "label_"
      (list->string
       (map (λ (c)
              (if (or (char<=? #\a c #\z)
                      (char<=? #\A c #\Z)
                      (char<=? #\0 c #\9)
                      (memq c '(#\_ #\$ #\# #\@ #\~ #\. #\?)))
                  c
                  #\_))
           (string->list (symbol->string s))))
      "_"
      (number->string (eq-hash-code s) 16))))