Journal


Free Monad - Interpreter pattern in F#

An analysis of the Free Monad - Interpreter pattern in F# from a definition created by erdeszt and based on: http://programmers.stackexchange.com/a/242803/145941

The DSL

First we define a DSL for our actions. Each action points to the next one through the generic type parameter 'next: every action carries a 'next value, which can in turn be another DSL value, thus chaining them.

type DSL<'next> =
    | Set of key: string * value: string *  'next
    | Get of key: string *       (string -> 'next)
  • Get returns a string, which is passed to a function (the continuation). The id function can be used to finish the chain.
  • Set doesn't return anything, so its 'next portion is simply the next DSL element in the chain, or a constant like () to finish.

Note that, used this way, 'next can be anything. It does not have to be a DSL value, so the type itself does not enforce a chain of DSLs.
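
To make this concrete, here is a minimal sketch that repeats the DSL type so it runs on its own. The names justANumber, lookup and step are illustrative only; step is a hypothetical helper that advances one action, answering a Get with a fixed string.

```fsharp
// The DSL type from above, repeated so this snippet is self-contained.
type DSL<'next> =
    | Set of key: string * value: string *  'next
    | Get of key: string *       (string -> 'next)

// 'next is an ordinary type parameter: here it is just an int, not another DSL value.
let justANumber : DSL<int> = Set ("count", "1", 42)

// id as the continuation finishes a Get: the retrieved string is the final value.
let lookup : DSL<string> = Get ("name", id)

// Hypothetical helper: advance one step, answering a Get with a fixed value.
let step (answer: string) (action: DSL<'next>) : 'next =
    match action with
    | Set (_, _, next) -> next
    | Get (_, fNext)   -> fNext answer
```

Here step "ignored" justANumber evaluates to 42 and step "John" lookup to "John"; neither value involves a second DSL layer.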

This is what 3 actions in the DSL may look like.

let ex1 = Set ("name", "John"
             , Get ("name"
                  , fun name -> Set ("greeting", sprintf "Hello %s" name, () )
                   )
              )

val ex1 : DSL<DSL<DSL<unit>>> = Set ("name","John",Get ("name",<fun:ex1@23-4>))

Notice how the resulting type DSL<DSL<DSL<unit>>> is nested and not generic: it grows with every action added to the chain. This means no single strongly typed function can process all possible values.

The Free Monad

Here comes the Free Monad ChainDSL to the rescue.

type ChainDSL<'a> =
    | Do     of DSL<ChainDSL<'a>>
    | Return of 'a

The Do case creates a chain of ChainDSLs that ends with the Return case. Whatever its length, the chain has type ChainDSL<'a>, where 'a is the type of the value carried by the final Return. This is almost like creating a list of DSLs (List<DSL<'a>> or DSL<'a> list), except that each element of the chain links to the next one itself, and some elements (Get) link to it through a function.

Let's look at the same value as above with ChainDSL:

let exF1 = Do (Set ("name", "John"
                  , Do (Get ("name"
                           , fun name -> Do (Set ("greeting", sprintf "Hello %s" name, Return () )) 
                            )
                       )
                   )
              )

val exF1 : ChainDSL<unit> = Do (Set ("name","John",Do (Get ("name",<fun:exF1@49-3>))))

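Compare the resulting type with the prior case: ChainDSL<unit> vs DSL<DSL<DSL<unit>>>. No matter how deep the chain is, the value will always be of type ChainDSL<'a>, where 'a is the type of the value returned at the end (unit here, string for a chain that ends by returning a retrieved value).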
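
The practical benefit of the uniform type can be sketched with a single recursive function that walks chains of any length. countSteps is a hypothetical helper, not part of the original; it feeds each Get continuation a placeholder string, much like interpreter1 further down does.

```fsharp
// DSL and ChainDSL from above, repeated so this snippet is self-contained.
type DSL<'next> =
    | Set of key: string * value: string *  'next
    | Get of key: string *       (string -> 'next)

type ChainDSL<'a> =
    | Do     of DSL<ChainDSL<'a>>
    | Return of 'a

// One recursive function covers chains of any depth, because every chain is a ChainDSL<'a>.
// Get continuations are fed a placeholder so the walk can reach the end of the chain.
let rec countSteps (chain: ChainDSL<'a>) : int =
    match chain with
    | Return _              -> 0
    | Do (Set (_, _, next)) -> 1 + countSteps next
    | Do (Get (key, fNext)) -> 1 + countSteps (fNext (sprintf "<%s>" key))

let oneStep : ChainDSL<unit> = Do (Set ("name", "John", Return ()))

// exF1 from above: three actions, yet the same type as oneStep.
let exF1 = Do (Set ("name", "John"
                  , Do (Get ("name"
                           , fun name -> Do (Set ("greeting", sprintf "Hello %s" name, Return () ))
                            )
                       )
                   )
              )
```

countSteps oneStep returns 1 and countSteps exF1 returns 3. An equivalent recursive function over the plain DSL type cannot be written, because each level of nesting there is a distinct type.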

But creating the chain is much more complex than before. To solve that, let's create two helper functions: get and set.

let get key       = Do (Get (key, fun value -> Return value))
let set key value = Do (Set (key,     value,   Return ()   ))    

val get : key:string -> ChainDSL<string>

val set : key:string -> value:string -> ChainDSL<unit>

Notice get returns a ChainDSL<string> and set returns a ChainDSL<unit>. They both return a ChainDSL chain with a single DSL action.

With these functions we can create Get & Set operations like this:

let setName     name = set "name"     name
let getName          = get "name"
let setGreeting name = set "greeting" (sprintf "Hello %s" name)

but they are not chained together like before.

Binding it together

To chain them we will need to define a bind function for the ChainDSL. We start with a map function for DSL, thus making DSL a functor:

let mapDSL: ('a -> 'b) -> DSL<'a> -> DSL<'b> = 
    fun     f             action  ->
        match action with
        | Get (key,        fNext) -> Get (key,        fNext >> f)
        | Set (key, value,  next) -> Set (key, value,  next |> f)

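All mapDSL does is apply the function f to the 'next part of the DSL, in other words, to the next node in the chain. For Set the next node is stored directly, so f is applied to it; for Get it is only available through the continuation, so f is composed after fNext.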
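
A quick check of both cases, assuming the DSL type and mapDSL exactly as defined above (repeated so the snippet is self-contained); mappedSet and mappedGet are illustrative names:

```fsharp
type DSL<'next> =
    | Set of key: string * value: string *  'next
    | Get of key: string *       (string -> 'next)

let mapDSL: ('a -> 'b) -> DSL<'a> -> DSL<'b> =
    fun     f             action  ->
        match action with
        | Get (key,        fNext) -> Get (key,        fNext >> f)
        | Set (key, value,  next) -> Set (key, value,  next |> f)

// Set: f is applied to the stored 'next value right away (41 becomes 42).
let mappedSet = Set ("k", "v", 41) |> mapDSL (fun n -> n + 1)

// Get: 'next only exists once a string is supplied, so f runs after the continuation.
let mappedGet = Get ("k", String.length) |> mapDSL (fun n -> n * 2)
```

Calling the continuation inside mappedGet with "abc" gives 6: String.length first, then the doubling added by mapDSL.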

Next we define the bind function for ChainDSL, finally making it a monad:

let bindChain: ('a -> ChainDSL<'b>) -> ChainDSL<'a> -> ChainDSL<'b> =
    fun        fChain                  chainTo      ->
        let rec appendTo chain =
            match chain with
            | Return a   -> fChain a
            | Do     dsl -> Do (mapDSL appendTo dsl)
        appendTo chainTo

bindChain acts much like the List.append function: it concatenates two chains of ChainDSLs. The difference is that the chain to be appended is produced by the function fChain from the value carried by the final Return of chainTo. bindChain navigates recursively down chainTo and replaces its last element with the result of fChain:

  • On the Do side, bindChain calls mapDSL to apply appendTo to the next ChainDSL node.
  • On the Return side, it replaces Return a with fChain a, the chain to be appended.

In a sense, ChainDSL is actually the opposite of a List<DSL<'a>>: in a list, new elements are inserted at the head; here they are attached at the tail end.
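
To see the tail-append behavior, here is a minimal sketch that repeats the definitions above in condensed form and adds describe, a hypothetical helper that lists the actions of a chain in order, answering each Get with a placeholder.

```fsharp
type DSL<'next> =
    | Set of key: string * value: string *  'next
    | Get of key: string *       (string -> 'next)

type ChainDSL<'a> =
    | Do     of DSL<ChainDSL<'a>>
    | Return of 'a

let mapDSL f action =
    match action with
    | Get (key,        fNext) -> Get (key,        fNext >> f)
    | Set (key, value,  next) -> Set (key, value,  next |> f)

// Same logic as bindChain above, written with ordinary parameters.
let bindChain (fChain: 'a -> ChainDSL<'b>) (chainTo: ChainDSL<'a>) : ChainDSL<'b> =
    let rec appendTo chain =
        match chain with
        | Return a   -> fChain a
        | Do     dsl -> Do (mapDSL appendTo dsl)
    appendTo chainTo

let get key       = Do (Get (key, fun value -> Return value))
let set key value = Do (Set (key,     value,   Return ()   ))

// Hypothetical helper: list the actions in order, answering every Get with "<key>".
let rec describe (chain: ChainDSL<'a>) : string list =
    match chain with
    | Return _              -> []
    | Do (Set (k, _, next)) -> ("Set " + k) :: describe next
    | Do (Get (k, fNext))   -> ("Get " + k) :: describe (fNext ("<" + k + ">"))

// The chain produced by the function ends up after the original one, at the tail.
let appended = set "a" "1" |> bindChain (fun () -> set "b" "2")
```

describe appended yields ["Set a"; "Set b"]: the new action was attached after the existing one rather than in front of it.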

Now we can bind setName, getName and setGreeting from above like this:

let exF2 = setName "John" 
           |> bindChain (fun _    -> getName         )
           |> bindChain (fun name -> setGreeting name)

val exF2 : ChainDSL<unit> = Do (Set ("name","John",Do (Get ("name",<fun:mapDSL@87-2>))))

which is the same as this:

let exF3 = set "name" "John" 
           |> bindChain (fun _ -> get "name"                           )
           |> bindChain (fun v -> set "greeting" (sprintf "Hello %s" v))

val exF3 : ChainDSL<unit> = Do (Set ("name","John",Do (Get ("name",<fun:mapDSL@87-2>))))

and this:

let (>>=) v f = bindChain f v

let exF4 = set "name" "John" 
           >>= fun _    -> get "name" 
           >>= fun name -> set "greeting" (sprintf "Hello %s" name)

val exF4 : ChainDSL<unit> = Do (Set ("name","John",Do (Get ("name",<fun:mapDSL@87-2>))))

Using Computation Expressions

Now let's try it with a computation expression. First we define a builder class.

type ChainDSLBuilder () =
    member this.Return      v = Return v
    member this.ReturnFrom mv = mv
    member this.Zero       () = Return ()
    member this.Bind   (v, f) = v >>= f

let chainDSL = ChainDSLBuilder ()

And now we use the computation expression like this.

let exF5 = chainDSL {
    do!         set "name"     "John"
    let! name = get "name"
    do!         set "greeting" (sprintf "Hello %s" name)
}

val exF5 : ChainDSL<unit> = Do (Set ("name","John",Do (Get ("name",<fun:mapDSL@87-2>))))

The Interpreters

Now we are going to create an interpreter to execute the AST we have built.

This first version is very simple: it does not store or retrieve any values; it just prints out the commands.

let rec interpreter1: ChainDSL<'a> -> 'a =
    fun               chain        ->
        match chain with
        | Return v -> printfn "return %A" v
                      v
        | Do   dsl -> 
            match dsl with
            | Get(key,        nextF) -> printfn "Get %s" key
                                        nextF (sprintf "<get.%s>" key) 
            | Set(key, value, next ) -> printfn "Set %s '%s'" key value
                                        next                           
            |> interpreter1

interpreter1 exF5
Set name 'John'
Get name
Set greeting 'Hello <get.name>'
return <null>

This next version actually stores and retrieves the values in a Map and, when finished, prints its contents. Note that Map.find throws a KeyNotFoundException if a Get asks for a key that was never set.

let interpreter2 chain = 
    let rec interpreter2r: Map<string, string> -> ChainDSL<'a> -> 'a =
        fun                dataStore              chain        ->
            match chain with
            | Return v -> printfn "return %A\n%A" v dataStore
                          v
            | Do   dsl -> 
                match dsl with
                | Get(key,        nextF) -> dataStore 
                                            |> Map.find key 
                                            |> (fun v -> printfn "Get %s -> '%s'" key v ; v )
                                            |> nextF
                                            |> interpreter2r dataStore
                | Set(key, value, next ) -> printfn "Set %s '%s'" key value
                                            next
                                            |> interpreter2r (dataStore |> Map.add key value)

    interpreter2r (Map.ofList []) chain

interpreter2 exF5
Set name 'John'
Get name -> 'John'
Set greeting 'Hello John'
return <null>
map [("greeting", "Hello John"); ("name", "John")]
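
Because the program is just data, other interpreters can run it unchanged. As an illustration, here is a sketch of a print-free variant (interpreterPure is a hypothetical name, not from the original) that returns the result together with the final store, which makes it convenient for tests.

```fsharp
// DSL and ChainDSL from above, repeated so this snippet is self-contained.
type DSL<'next> =
    | Set of key: string * value: string *  'next
    | Get of key: string *       (string -> 'next)

type ChainDSL<'a> =
    | Do     of DSL<ChainDSL<'a>>
    | Return of 'a

// Hypothetical print-free interpreter: threads the store through and returns it with the result.
let interpreterPure (chain: ChainDSL<'a>) : 'a * Map<string, string> =
    let rec run (store: Map<string, string>) (chain: ChainDSL<'a>) =
        match chain with
        | Return v                    -> v, store
        | Do (Get (key, nextF))       -> run store (nextF (Map.find key store))
        | Do (Set (key, value, next)) -> run (Map.add key value store) next
    run Map.empty chain

// A small program built by hand: set a name, read it back, store a greeting, return the name.
let program : ChainDSL<string> =
    let greet name = Do (Set ("greeting", sprintf "Hello %s" name, Return name))
    Do (Set ("name", "John", Do (Get ("name", greet))))
```

Running interpreterPure program gives back "John" and a store holding both the name and "Hello John"; the program value itself did not change, only the interpreter did.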

A slightly longer example:

chainDSL {
    do!           set "first-name" "John"
    do!           set "last-name"  "Smith"
    let! first  = get "first-name"
    let! last   = get "last-name"
    do!           set "full-name" (first + " "  + last)
    let! full   = get "full-name"
    return        sprintf "Hello %s" full
}
|> interpreter2

Output:

Set first-name 'John'
Set last-name 'Smith'
Get first-name -> 'John'
Get last-name -> 'Smith'
Set full-name 'John Smith'
Get full-name -> 'John Smith'
return "Hello John Smith"
map
  [("first-name", "John"); ("full-name", "John Smith"); ("last-name", "Smith")]

Return value:

"Hello John Smith"

Replicating this last example without the computation expression requires explicitly nesting some of the calls, because first must still be in scope when full-name is set.

It would look like this:

set "first-name" "John" 
>>= fun _     -> set "last-name"  "Smith"            
>>= fun _     -> get "first-name"                    
>>= fun first -> get "last-name" 
                 >>= fun last -> set "full-name" (first + " "  + last)
>>= fun _     -> get "full-name"
>>= fun full  -> Return (sprintf "Hello %s" full)
|> interpreter2

Output:

Set first-name 'John'
Set last-name 'Smith'
Get first-name -> 'John'
Get last-name -> 'Smith'
Set full-name 'John Smith'
Get full-name -> 'John Smith'
return "Hello John Smith"
map
  [("first-name", "John"); ("full-name", "John Smith"); ("last-name", "Smith")]

Return value:

"Hello John Smith"

Two in one

So, do we really need two types, the Free Monad and the DSL?

I do not think it is necessary. The free monad helps in binding the elements of the DSL, and the same can be achieved just by adding a Return case to the DSL itself.

Here is the same implementation with just the DSL type:

module DSL2 =
    type DSL<'a> =
        | Set of key: string * value: string *  DSL<'a>
        | Get of key: string *       (string -> DSL<'a>)
        | Return of 'a
    
    let set key value = Set (key, value,          Return ())
    let get key       = Get (key,        fun v -> Return v )
    
    let bind: ('a -> DSL<'b>) -> DSL<'a> -> DSL<'b> =
        fun   fChain             chainTo ->
           let rec appendTo chain =
               match chain with
               | Set (k, v,  next) -> Set (k, v,  next |> appendTo)
               | Get (k,    fNext) -> Get (k,    fNext >> appendTo)
               | Return  v         -> fChain v
           appendTo chainTo

    let (>>=) v f = bind f v

    let interpreter2 dsl =
        let rec interpreter2r: Map<string, string> -> DSL<'a> -> 'a =
            fun                dataStore              dslR    ->
                match dslR with
                | Return v               -> printfn "return %A\n%A" v dataStore
                                            v
                | Get(key,        nextF) -> dataStore 
                                            |> Map.find key 
                                            |> (fun v -> printfn "Get %s -> '%s'" key v ; v )
                                            |> nextF
                                            |> interpreter2r dataStore
                | Set(key, value, next ) -> printfn "Set %s '%s'" key value
                                            next
                                            |> interpreter2r (dataStore |> Map.add key value)
        interpreter2r (Map.ofList []) dsl

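There you have it: the DSL definition, the helper functions, the bind function and the interpreter. Here is the last example again; it is indented because it is evaluated inside module DSL2, so that set, get, >>= and Return refer to the new definitions.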

    set "first-name" "John" 
    >>= fun _     -> set "last-name"  "Smith"            
    >>= fun _     -> get "first-name"                    
    >>= fun first -> get "last-name" 
                     >>= fun last -> set "full-name" (first + " "  + last)
    >>= fun _     -> get "full-name"
    >>= fun full  -> Return (sprintf "Hello %s" full)
    |> interpreter2

Output:

Set first-name 'John'
Set last-name 'Smith'
Get first-name -> 'John'
Get last-name -> 'Smith'
Set full-name 'John Smith'
Get full-name -> 'John Smith'
return "Hello John Smith"
map
  [("first-name", "John"); ("full-name", "John Smith"); ("last-name", "Smith")]

Return value:

"Hello John Smith"
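
The single-type DSL can be used with a computation expression too. A minimal sketch, assuming a builder shaped like ChainDSLBuilder; DSLBuilder, dsl and the print-free run function are illustrative names, not part of the original, and the module is repeated in condensed form so the snippet runs on its own:

```fsharp
module DSL2 =
    type DSL<'a> =
        | Set of key: string * value: string *  DSL<'a>
        | Get of key: string *       (string -> DSL<'a>)
        | Return of 'a

    let set key value = Set (key, value,          Return ())
    let get key       = Get (key,        fun v -> Return v )

    let bind (fChain: 'a -> DSL<'b>) (chainTo: DSL<'a>) : DSL<'b> =
        let rec appendTo chain =
            match chain with
            | Set (k, v,  next) -> Set (k, v,  next |> appendTo)
            | Get (k,    fNext) -> Get (k,    fNext >> appendTo)
            | Return  v         -> fChain v
        appendTo chainTo

    // Hypothetical builder mirroring ChainDSLBuilder, but over the single DSL type.
    type DSLBuilder () =
        member this.Return      v = Return v
        member this.ReturnFrom mv = mv
        member this.Zero       () = Return ()
        member this.Bind   (v, f) = bind f v

    let dsl = DSLBuilder ()

    // Hypothetical print-free interpreter, so the result can be checked directly.
    let rec run (store: Map<string, string>) (program: DSL<'a>) : 'a =
        match program with
        | Return v               -> v
        | Get (key, nextF)       -> run store (nextF (Map.find key store))
        | Set (key, value, next) -> run (Map.add key value store) next

open DSL2

let greeting =
    dsl {
        do!          set "first-name" "John"
        do!          set "last-name"  "Smith"
        let! first = get "first-name"
        let! last  = get "last-name"
        return sprintf "Hello %s %s" first last
    }
```

With the builder in place, the let!/do! version reads just like the ChainDSL one, even though there is no separate free monad type; run Map.empty greeting evaluates to "Hello John Smith".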