1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import scala.collection.mutable
import scala.util.continuations._

/**
 * This module defines a paradigm for data modeling and fetching in web
 * applications. The data model is that of a graph with "objects" as vertices
 * and "edges" connecting them. Each type of edge is an object in itself.
 * The fetching paradigm is to use continuations to maximize batch fetching of
 * large amounts of objects across an entire application at once.
 */
object Loader {

  // Local import so the tail-recursion check stays self-contained in this object.
  import scala.annotation.tailrec

  /**
   * Captured state of a computation that eventually yields an R.
   *
   * A [[Continuation]] is an unfinished computation: calling `next(())` runs
   * exactly one more step and returns the following state. A [[Result]] is a
   * finished computation holding its final value.
   *
   * The hierarchy is sealed so that matches over it (such as `isContinuation`)
   * are checked for exhaustiveness by the compiler.
   */
  sealed abstract class ExecState[+R] {
    /** True while the computation still has steps left to run. */
    def isContinuation: Boolean = this match {
      case Continuation(_) => true
      case Result(_)       => false
    }
  }

  /** An unfinished computation; `next(())` runs one more step. */
  final case class Continuation[+R](next: Unit => ExecState[R])
    extends ExecState[R]

  /** A finished computation holding its final value. */
  final case class Result[+R](result: R)
    extends ExecState[R]

  /** True when no state in `states` still has steps left to run. */
  private def allDone(states: Seq[ExecState[Any]]): Boolean =
    !states.exists(_.isContinuation)

  /** The payloads of the finished states, in their original order. */
  private def results(states: Seq[ExecState[Any]]): Seq[Any] =
    states.collect { case Result(x) => x }

  /** Advances every unfinished state by one step; finished states pass through. */
  private def stepAll(states: Seq[ExecState[Any]]): Seq[ExecState[Any]] =
    states.map {
      case Continuation(next) => next(())
      case done @ Result(_)   => done
    }

  /**
   * Builds a computation that first drives a set of dependencies to completion
   * and then hands their results to a callback.
   *
   * Nothing is evaluated here. The returned function, given the callback,
   * produces a [[Continuation]]. Each time that continuation is run, it either:
   *  - calls `after` with the dependencies' results, once every dependency is
   *    a [[Result]]; or
   *  - advances every unfinished dependency by one step and returns a fresh
   *    continuation for the next round.
   *
   * Because every dependency advances in lock-step, the data fetches they
   * queue during a step can be served together by one batched query (see
   * `process`). The shape of the returned value is what `shift` expects, so
   * this is normally used as `shift { fetch[R](deps: _*) }` inside a `reset`.
   *
   * @param first The dependencies that must complete first.
   * @return      A function from the callback `after` (which receives the
   *              dependencies' results in order) to a continuation that is
   *              ready to execute.
   */
  def fetch[R](
    first: ExecState[Any]*
  ): (Seq[Any] => ExecState[R]) => ExecState[R] =
    (after: Seq[Any] => ExecState[R]) =>
      Continuation[R] { (_: Unit) =>
        if (allDone(first)) after(results(first))
        else fetch[R](stepAll(first): _*)(after)
      }

  /**
   * Runs the given computations to completion, one batched step at a time.
   *
   * Before every step, `Object.query()` is called once so that all object
   * fetches queued during the previous step are served by a single query.
   *
   * @param states The computations to drive.
   * @return       The final value of each computation, in the order given.
   */
  def process(states: ExecState[Any]*): Seq[Any] = drive(states)

  /** Tail-recursive loop behind `process`; runs in constant stack space. */
  @tailrec
  private def drive(states: Seq[ExecState[Any]]): Seq[Any] =
    if (allDone(states)) {
      results(states)
    } else {
      // The point of the whole continuation machinery: one query serves every
      // fetch that any computation queued during the previous step.
      Object.query()
      drive(stepAll(states))
    }

  /**
   * An object as represented in the data model. It has three parts:
   *  - a unique ID,
   *  - a key-value store of string fields,
   *  - a map of edges to other objects (meant to be fetched on demand; nothing
   *    in this module populates or reads it yet).
   *
   * The constructor is private; instances are created by the companion, which
   * owns fetching and caching.
   */
  class Object private (
    val id: Long,
    val fields: Map[String, String],
    private val edges: mutable.Map[Long, Long]
  )

  object Object {

    /** IDs requested since the last `query()`, waiting to be fetched. */
    private val queued: mutable.Set[Long] = mutable.Set[Long]()

    /** Every object fetched so far, keyed by ID. Entries are never evicted. */
    private val cached: mutable.Map[Long, Object] = mutable.Map[Long, Object]()

    /**
     * Fetches an object by ID.
     *
     * If the object is already cached, the returned state is an immediate
     * [[Result]]. Otherwise the ID is queued for the next `query()`, and the
     * returned [[Continuation]] yields the object one step later. When it is
     * driven by `process`, the query runs before that step.
     *
     * NOTE(review): if `query()` produces no object for `id`, the lookup after
     * the shift throws NoSuchElementException.
     */
    def get(id: Long): ExecState[Object] = reset {
      if (!cached.contains(id)) {
        // Record the ID so the next batched query fetches it.
        queued += id
        // Suspend; a fetch with no dependencies resumes on the very next step.
        shift { fetch[Object]() } : Seq[Any] @cps[ExecState[Object]]
        Result(cached(id))
      } else {
        Result(cached(id))
      }
    }

    /**
     * Fetches every queued ID in one batch, stores the objects in the cache
     * and empties the queue.
     *
     * TODO magic SQL: currently a stand-in that fabricates, for each queued
     * ID, an object whose only field is "test" -> "5" and which has no edges.
     */
    def query(): Unit = {
      for (id <- queued) {
        cached(id) = new Object(
          id,
          Map[String, String]("test" -> "5"),
          mutable.Map[Long, Long]()
        )
      }
      queued.clear()
    }

  }

}

// Script code for testing

import Loader._

// A computation that loads one object and reads its "test" field as an Int.
val getFive = reset {
  val fetched: Seq[Any] = shift { fetch[Int](Object.get(15181990251L)) }
  val five = fetched match {
    case Seq(loaded: Object) => loaded.fields("test").toInt
  }
  Result(five)
}

// A computation that waits on getFive and an already-known 3, then adds them.
val getEight = reset {
  val operands: Seq[Any] = shift { fetch[Int](getFive, Result(3)) }
  val total = operands match {
    case Seq(left: Int, right: Int) => left + right
  }
  Result(total)
}

// Drive everything to completion (batching the object fetch) and print the sum.
val a: Int = process(getEight) match {
  case Seq(sum: Int) => sum
}
println(a)