Surrounding code with common logic in Scala

A common programming need is to wrap the same setup and teardown logic around varying pieces of code. For example, every database action needs to run inside a transaction, which must also be rolled back if an error occurs.

I wanted to see whether I could write a generic function that is as easy to use as this:

transactional {
  save("something")
}

My first attempt looked like this. I assumed I needed an implicit Connection that would be passed from transactional to save behind the scenes, without appearing in the calling code:

def transactional(f: => Unit): Unit = {
  implicit val conn: Connection = ??? // placeholder: obtain a real connection here

  conn.begin()
  try {
    f
    conn.commit()
  } catch {
    case ex: Exception =>
      conn.rollback()
      throw ex
  } finally {
    conn.close()
  }
}

def save(operation: String)(implicit conn: Connection): Unit = {
  println(s"saving $operation using $conn")
}

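Unfortunately, the calling code will not compile. The compiler resolves save's implicit Connection where save("something") is written, in the caller's scope, not inside transactional where the block eventually runs. There is no implicit Connection in scope at the call site, so the lookup fails.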
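For comparison, Scala 3 (not Scala 2) can express this first attempt almost directly through context function types, written Connection ?=> Unit. A block passed where such a type is expected is compiled as a function that receives the Connection as a given, so calls like save find it without the caller ever naming it. This is a sketch under stated assumptions: the recording Connection class and the openConnection factory are hypothetical stand-ins for the ??? placeholder, and save uses Scala 3's using clause instead of implicit.

```scala
import scala.collection.mutable.ListBuffer

// Hypothetical stand-in for a real connection: it only records the calls made on it.
final class Connection {
  val log: ListBuffer[String] = ListBuffer.empty[String]
  def begin(): Unit = log += "begin"
  def commit(): Unit = log += "commit"
  def rollback(): Unit = log += "rollback"
  def close(): Unit = log += "close"
}

// Hypothetical factory replacing `???`; every connection it opens is kept for inspection.
val opened: ListBuffer[Connection] = ListBuffer.empty[Connection]
def openConnection(): Connection = {
  val conn = new Connection
  opened += conn
  conn
}

// `Connection ?=> Unit` is a context function type (Scala 3 only): the caller's
// block is compiled as a function that receives the Connection as a given.
def transactional(f: Connection ?=> Unit): Unit = {
  val conn = openConnection()
  conn.begin()
  try {
    f(using conn) // supply the connection to the block explicitly
    conn.commit()
  } catch {
    case ex: Exception =>
      conn.rollback()
      throw ex
  } finally {
    conn.close()
  }
}

// `using` is Scala 3's replacement for an implicit parameter list.
def save(operation: String)(using conn: Connection): Unit =
  conn.log += s"save $operation"

// The desired calling syntax, even with a loop, compiles as written.
transactional {
  for (_ <- 1 to 3) save("something")
}
```

The trade-off is that this relies on a Scala 3 language feature; in Scala 2 the connection has to reach the block some other way.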

This simpler version drops the implicits: transactional takes a function from Connection to Unit, and save returns such a function. It works with the calling snippet above:

def transactional(f: Connection => Unit): Unit = {
  val conn: Connection = ??? // placeholder: obtain a real connection here

  conn.begin()
  try {
    f(conn)
    conn.commit()
  } catch {
    case ex: Exception =>
      conn.rollback()
      throw ex
  } finally {
    conn.close()
  }
}

def save(operation: String): Connection => Unit = { conn =>
  println(s"saving $operation using $conn")
}

transactional {
  save("something")
}

However, once I started expanding the calling code, it no longer worked:

transactional {
  for (i <- 1 to 3) {
    save("something")
  }
}

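The block now evaluates to the for loop's result, Unit, rather than a function from Connection to Unit, so it no longer type-checks. Even if it did, the functions returned by save would be created and thrown away without ever being called with a connection.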
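The loop version can still be made to compile by naming the connection and applying each function returned by save to it explicitly. A minimal sketch, assuming a hypothetical recording Connection and a hypothetical openConnection factory in place of ???; it also shows why this gets tedious, since every call site must repeat (conn).

```scala
import scala.collection.mutable.ListBuffer

// Hypothetical stand-in for a real connection: it only records the calls made on it.
final class Connection {
  val log: ListBuffer[String] = ListBuffer.empty[String]
  def begin(): Unit = log += "begin"
  def commit(): Unit = log += "commit"
  def rollback(): Unit = log += "rollback"
  def close(): Unit = log += "close"
}

// Hypothetical factory replacing `???`; opened connections are kept for inspection.
val opened: ListBuffer[Connection] = ListBuffer.empty[Connection]
def openConnection(): Connection = {
  val conn = new Connection
  opened += conn
  conn
}

def transactional(f: Connection => Unit): Unit = {
  val conn = openConnection()
  conn.begin()
  try {
    f(conn)
    conn.commit()
  } catch {
    case ex: Exception =>
      conn.rollback()
      throw ex
  } finally {
    conn.close()
  }
}

// Returns a function that records the operation on whichever connection it is given.
def save(operation: String): Connection => Unit = { conn =>
  conn.log += s"save $operation"
}

// Name the connection, then apply each function returned by `save` to it.
transactional { conn =>
  for (_ <- 1 to 3) save("something")(conn)
}
```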

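Back to implicits, then. This version combines both approaches: transactional still passes the Connection as a function argument, but the caller marks the lambda parameter as implicit. That puts the connection in implicit scope for the whole block, so any function that needs it, such as save, picks it up automatically: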

def transactional(f: Connection => Unit): Unit = {
  val conn: Connection = ??? // placeholder: obtain a real connection here

  conn.begin()
  try {
    f(conn)
    conn.commit()
  } catch {
    case ex: Exception =>
      conn.rollback()
      throw ex
  } finally {
    conn.close()
  }
}

def save(operation: String)(implicit conn: Connection): Unit = {
  println(s"saving $operation using $conn")
}

transactional { implicit conn =>
  for (i <- 1 to 3) {
    save("something")
  }
}

This begins the transaction, prints "saving something..." three times, then commits and closes the connection, just as intended.
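To check the error path as well, here is a runnable sketch of this final version. It assumes a hypothetical Connection that records calls instead of talking to a database and a hypothetical openConnection factory in place of ???; save records the operation instead of printing it. The second call throws inside the block to show that the transaction is rolled back, the connection is still closed, and the exception reaches the caller.

```scala
import scala.collection.mutable.ListBuffer

// Hypothetical stand-in for a real connection: it only records the calls made on it.
final class Connection {
  val log: ListBuffer[String] = ListBuffer.empty[String]
  def begin(): Unit = log += "begin"
  def commit(): Unit = log += "commit"
  def rollback(): Unit = log += "rollback"
  def close(): Unit = log += "close"
}

// Hypothetical factory replacing `???`; opened connections are kept for inspection.
val opened: ListBuffer[Connection] = ListBuffer.empty[Connection]
def openConnection(): Connection = {
  val conn = new Connection
  opened += conn
  conn
}

def transactional(f: Connection => Unit): Unit = {
  val conn = openConnection()
  conn.begin()
  try {
    f(conn)
    conn.commit()
  } catch {
    case ex: Exception =>
      conn.rollback()
      throw ex
  } finally {
    conn.close()
  }
}

// Records the operation on the implicit connection instead of printing it.
def save(operation: String)(implicit conn: Connection): Unit =
  conn.log += s"save $operation"

// Happy path: three saves, then commit and close.
transactional { implicit conn =>
  for (i <- 1 to 3) save(s"something $i")
}

// Failure path: the exception triggers a rollback, then is rethrown to the caller.
val caught: Option[String] =
  try {
    transactional { implicit conn =>
      save("first")
      throw new IllegalStateException("boom")
    }
    None
  } catch {
    case ex: IllegalStateException => Some(ex.getMessage)
  }
```

Rethrowing after the rollback keeps the failure visible to the caller; swallowing it would make a failed transaction look like a successful one.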