all that jazz

james' blog about scala and all that jazz

Three things I had to learn about Scala before it made sense

Last year I sat down in a meeting room with a few other colleagues to learn from one of Atlassian's biggest Scala advocates about Scala. I kept on hearing good stuff about Scala, and I like trying new things, so I was eager to have my eyes opened so that I could learn how awesome this thing was. Unfortunately, I was disappointed. Most of the lesson went straight over my head. But to my dismay, the other people in the room seemed to get it.

It was on this day that I realised an interesting thing about my learning style. I remembered back at uni, I was exactly the same in lectures... I never got anything I was taught in lectures. Which is probably why by the end of uni, I had pretty much stopped attending all CS lectures. It wasn't until I was given a real problem in a lab that I had to solve that I got it. This is why I never really got functional languages at uni: the lecturers that taught functional languages at my uni weren't the most practical of people. Actually, I don't think they ever even considered whether their work had any practical implications at all. So they taught things like map, reduce, filter, fold, curried functions etc. like this: "Look, here's something you've never heard of! Behold its awesomeness! Now here's a trivial mathematical use case with no apparent real world application! Look at how practical it is!"

Some people have no problems learning this way, and I envy them. For me, I need to have a problem that I understand and want to solve today, that the thing that I'm learning can solve, in order to learn it. So if you're reading this and thinking "That is so me!", and you want to learn Scala, then hopefully this is the blog post for you. I'm going to start with a practical problem that hopefully you will understand and want to solve, and then show you how three particular Scala features allow solving this really elegantly and simply. Hopefully after that, you'll have a good understanding of those features in Scala, and be able to understand a lot more of Scala code when you encounter it.

The problem

The problem that I want to solve is dependency management. If you've ever used Maven, you'll know how verbose this can be. For those that haven't encountered Maven, here's what specifying a single dependency looks like:

<dependency>
    <groupId>net.vz.mongodb.jackson</groupId>
    <artifactId>mongo-jackson-mapper</artifactId>
    <version>1.4.1</version>
</dependency>

In English, this means a dependency on the artifact mongo-jackson-mapper version 1.4.1 from the organisation net.vz.mongodb.jackson. The English explanation is shorter than the code! Maybe the Maven authors could have made things nicer, by placing it on one line, and that would be clearer too, for example, something like this:

<dependency>net.vz.mongodb.jackson:mongo-jackson-mapper:1.4.1</dependency>

There are, however, some advantages to the more verbose model. The language (XML) that the DSL (the Maven pom.xml format) uses is able to fully express every element of the dependency, rather than the elements being defined in some other format that XML is unaware of. There are also other attributes (such as scope and classifier) that the more verbose solution can specify easily and unambiguously; these become very hard to express if the dependency is just an arbitrary string of characters.

So here is the problem. I want a DSL to specify dependencies with as little noise as possible from the DSL, so that my dependency specification is short, unambiguous, and its format should be enforced by the language it's specified in. The language I'm going to use is Scala.

And just to ground this problem even more solidly in reality, this example is not contrived. SBT, the Scala Build Tool, uses Scala to specify its build configuration, including dependencies. The solution I'm going to show you is the solution that SBT uses. So by the end of this you will also understand a little about how to use SBT, which you will no doubt need to know as you delve into the world of Scala.

How would we do it with no weird Scala tricks?

First of all I will give a solution that uses language constructs that you should already be familiar with, using classes, fields, methods etc. If the following syntax makes no sense to you, you should probably do some reading on basic Scala syntax.

def groupId(groupId: String) = new GroupId(groupId)

class GroupId(val groupId: String) {
  def artifact(artifactId: String) = new Artifact(groupId, artifactId)
}

class Artifact(val groupId: String, val artifactId: String) {
  def version(version: String) = new VersionedArtifact(groupId, artifactId, version)
}

class VersionedArtifact(val groupId: String, val artifactId: String, val version: String) {
}

Here we have three classes, a GroupId, which has a method for creating an Artifact from that group id, which in turn has a method for creating a VersionedArtifact, which captures our dependency. There's also a factory method for creating the GroupId. To use it, we can do this:

groupId("net.vz.mongodb.jackson")
    .artifact("mongo-jackson-mapper")
    .version("1.4.1")

Already it's looking better than the XML version, but this could be done in Java too. Scala has a lot more to offer yet. As we go through the following lessons, it's important to remember that each change we make, we are still basically achieving the same as the above.
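Since I've claimed this could be done in Java too, here is a minimal sketch of what that might look like. The class and method names simply mirror the Scala ones above; the Dependencies class is a hypothetical holder I've added so that the factory method and a small usage example have somewhere to live.

```java
// Hypothetical holder class: Java has no top-level functions, so the factory lives here.
final class Dependencies {
    static GroupId groupId(String groupId) {
        return new GroupId(groupId);
    }

    public static void main(String[] args) {
        VersionedArtifact dep = groupId("net.vz.mongodb.jackson")
            .artifact("mongo-jackson-mapper")
            .version("1.4.1");
        System.out.println(dep.groupId + ":" + dep.artifactId + ":" + dep.version);
    }
}

final class GroupId {
    final String groupId;

    GroupId(String groupId) {
        this.groupId = groupId;
    }

    // Mirrors GroupId.artifact in the Scala version
    Artifact artifact(String artifactId) {
        return new Artifact(groupId, artifactId);
    }
}

final class Artifact {
    final String groupId;
    final String artifactId;

    Artifact(String groupId, String artifactId) {
        this.groupId = groupId;
        this.artifactId = artifactId;
    }

    // Mirrors Artifact.version in the Scala version
    VersionedArtifact version(String version) {
        return new VersionedArtifact(groupId, artifactId, version);
    }
}

final class VersionedArtifact {
    final String groupId;
    final String artifactId;
    final String version;

    VersionedArtifact(String groupId, String artifactId, String version) {
        this.groupId = groupId;
        this.artifactId = artifactId;
        this.version = version;
    }
}
```

The call chain reads the same as the Scala one, but the definitions are several times longer, and this is about as far as Java will let us take it.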

Lesson One: Scala methods can be made of operator characters

To Java developers such as myself, this is a bit weird, though in other languages this is quite common. There are some strict rules in Scala over exactly what a method name can be, but one of them is that it can be made of just operator characters (i.e. +, =, *, /, %, -). So, the first change that we're going to make to our above code is to get rid of the method names. It makes things more concise, but other than that it may seem like a weird first step. Bear with me.

def groupId(groupId: String) = new GroupId(groupId)

class GroupId(val groupId: String) {
  def %(artifactId: String) = new Artifact(groupId, artifactId)
}

class Artifact(val groupId: String, val artifactId: String) {
  def %(version: String) = new VersionedArtifact(groupId, artifactId, version)
}

class VersionedArtifact(val groupId: String, val artifactId: String, val version: String) {
}

And so now our code looks like this:

groupId("net.vz.mongodb.jackson").%("mongo-jackson-mapper").%("1.4.1")

Lesson Two: Scala lets you call methods without dots or parentheses

There are some strict rules governing this, including what happens when obvious ambiguities arise, and you can read about that in your own research. Simply put, if a method has only one parameter, you can call it on an object without using the dot or the parentheses, replacing them with whitespace. The groupId factory method isn't called on an object, so it keeps its parentheses:

groupId("net.vz.mongodb.jackson") % "mongo-jackson-mapper" % "1.4.1"

This is called "infix" notation, which you may have also come across in other languages. We're looking much more concise now, yet the language is still enforcing the rules of what a dependency needs, and we still end up with a strongly typed VersionedArtifact. But we're not finished yet.

Lesson Three: Scala lets you implicitly convert any type to any other type

For me, this was the weirdest feature of Scala when I first learnt it. It sounded so cool and dangerous at the same time. Most languages, when you call a method on an object, if that method doesn't exist on that object, will throw an error. For the statically typed languages, this will be a compile error, for dynamic languages it will be a runtime error. However, Scala is a bit different. When it can't find a method on an object (at compile time, because Scala is statically typed), it has an extra step before it gives up with an error.

Scala has a concept of implicit methods (and fields, but right now let's only worry about methods). When a method that you call doesn't exist on the object that you're trying to invoke it on, it will check the current scope for any implicit methods that take the type of your object as an argument. If the implicit method's result has a method that matches the method you are trying to invoke, then Scala wraps your object in that implicit method call. Have I lost you? Let's look at some code:

implicit def groupId(groupId: String) = new GroupId(groupId)

So the only change I've made to the groupId method is that it is now implicit. When I define my dependency, I do this:

"net.vz.mongodb.jackson" % "mongo-jackson-mapper" % "1.4.1"

Notice that the call to groupId has completely disappeared. If at this point our newly learned Scala syntax is a bit overwhelming for you, maybe it would help to see what would happen if we applied the implicit rule first, before we got rid of our dots and parentheses, and before we changed our method names to percent signs:

"net.vz.mongodb.jackson".artifact("mongo-jackson-mapper").version("1.4.1")

What it looks like is that the String class has a method called artifact. But we know it doesn't, and the Scala compiler knows it doesn't. When it sees that, it says "are there any implicit methods available that accept a String as an argument, and return a type that has an artifact method?" And so it finds the groupId method, which matches those criteria, and converts the code to this:

groupId("net.vz.mongodb.jackson").artifact("mongo-jackson-mapper").version("1.4.1")

When I first learnt about this feature, I thought that's going to cause massive amounts of confusion. And, if used irresponsibly, it certainly can. But when used carefully, it provides a massive amount of power, for example, in creating DSLs like the one above, or for adding methods to third party/legacy libraries.

On a side note, Scala actually uses this to add functional methods to the Java collections API, by providing implicit conversion methods to convert Java collections to Scala collections. The confusion that could arise can be avoided by being careful to only import the implicit methods you need, where you need them. Scala allows you to put an import statement anywhere in code, so if you have one method in a class that needs to work with Java collections, you can import the implicit collection conversions into just that block of code. In the case of DSLs, it's usually the case that the use of these are separated from the rest of your code by normal encapsulation, so a DSL for accessing a database would only appear in a DAO, where you're expecting it to appear, so there would be no confusion.

Putting it into context

There is one last step, and that is to show what we actually do with the dependency we've declared. In Maven, our XML ends up being part of the rest of the overly verbose POM file. In SBT, there is a file called build.sbt that contains snippets of Scala, which build up the configuration. Available to the snippets are certain variables, for example, libraryDependencies, which contain the configuration. So adding my dependency to my SBT project means adding a line to the build.sbt file, like this:

libraryDependencies += "net.vz.mongodb.jackson" % "mongo-jackson-mapper" % "1.4.1"

You may notice another use of infix notation with an operator-named method here: += adds an element, in the same way as the += method on Scala's mutable collections. Multiple dependencies can be added more tersely like this:

libraryDependencies ++= Seq(
    "net.vz.mongodb.jackson" % "mongo-jackson-mapper" % "1.4.1",
    "net.vz.mongodb.jackson" % "play-mongo-jackson-mapper" % "1.0.0"
)

So there we have it. We have a very terse way of specifying dependencies, that is strongly typed, and so validated at compile time. I hope that the problem I have presented and solved is a problem that you have understood and can relate to, and that you've understood how Scala can be used to help solve it, and that in learning how Scala helps solve it, you've learnt and understood some important features of the Scala language that you're now keen to apply to other problems that you have.

Open Source Developers

Since everyone's doing it, I thought I'd do it too. My contribution to the What people think I do meme:

Cracking Random Number Generators - Part 4

In Part 3 of this series, we investigated the Mersenne Twister, and saw how with 624 consecutive integers obtained from it, we can predict every subsequent integer it will produce. In this part, we will look at how to calculate previous integers that it has produced. We will also discuss how we might go about determining the internal state of the Mersenne Twister if we are unable to get 624 consecutive integers.

Breaking it down

Before we go on, let's look again at the algorithm for generating the next value for state[i]:

int y = (state[i] & 0x80000000) + (state[(i + 1) % 624] & 0x7fffffff);
int next = y >>> 1;
if ((y & 1L) == 1L) {
  next ^= 0x9908b0df;
}
next ^= state[(i + 397) % 624];

We can see that there are 3 numbers from the previous state that are involved here: the old state[i], state[(i + 1) % 624], and state[(i + 397) % 624]. So, taking these 3 numbers, let's have a look at what this looks like in binary:

1. 11100110110101000100101111000001 // state[i]
2. 10101110111101011001001001011111 // state[i + 1]
3. 11101010010001001010000001001001 // state[i + 397]

// y = state[i] & 0x80000000 | state[i + 1] & 0x7fffffff
4. 10101110111101011001001001011111 // y
5. 01010111011110101100100100101111 // next = y >>> 1
6. 11001110011100100111100111110000 // next ^= 0x9908b0df
7. 00100100001101101101100110111001 // next ^= state[i + 397]

If we work backwards from i = 623, then the pieces of information we have from the above equation are the end result (7), state[i + 1] and state[i + 397]. Starting from the result, the easiest step to unapply is the last one: undoing an xor is as simple as applying that same xor again. So we can get from 7 to 6 by xoring with state[i + 397].

Getting from 6 to 5 depends on whether y was odd; if it wasn't, then no operation was applied. But we can also see from the bitshift to the right applied from 4 to 5 that the first bit at 5 will always be 0. Additionally, the first bit of the magic number xored at 6 is 1. So, if the first bit of the number at 6 is 1, then the magic number must have been applied, otherwise it wasn't. Hence, we can conditionally unapply step 6.

At this point, we have the first bit of the old state[i] calculated, in addition to the middle 30 bits of state[i + 1] calculated. We can also infer the last bit of state[i + 1], it is the same as the last bit of y, and if y was odd, then the magic number was applied at step 6, otherwise it wasn't. We've already worked out whether the magic number was applied at step 6, so if it was, the last bit of state[i + 1] was 1, or 0 otherwise.
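To make the steps above concrete, here is a small sketch that walks the worked example backwards from line 7 to line 4. The class name UndoOneStep and the helpers recoverY and bits are my own, introduced just for this illustration; the binary strings are copied from the listing above.

```java
final class UndoOneStep {
    static final int MAGIC = 0x9908b0df;

    // Given the new state[i] (line 7) and state[i + 397] (line 3), recover y (line 4).
    static int recoverY(int next, int state397) {
        int tmp = next ^ state397;                 // 7 -> 6: apply the same xor again
        boolean magicApplied = (tmp & 0x80000000) != 0;
        if (magicApplied) {
            tmp ^= MAGIC;                          // 6 -> 5: undo the conditional xor
        }
        int y = tmp << 1;                          // 5 -> 4: undo the shift, last bit still unknown
        if (magicApplied) {
            y |= 1;                                // y was odd exactly when the magic number was applied
        }
        return y;
    }

    // Parses a 32 character binary string into an int.
    static int bits(String binary) {
        return Integer.parseUnsignedInt(binary, 2);
    }

    public static void main(String[] args) {
        int next = bits("00100100001101101101100110111001");     // line 7
        int state397 = bits("11101010010001001010000001001001"); // line 3
        int y = recoverY(next, state397);
        System.out.println("y:                            " + Integer.toBinaryString(y));
        System.out.println("first bit of old state[i]:    " + (y >>> 31));
        System.out.println("last 31 bits of state[i + 1]: " + Integer.toBinaryString(y & 0x7fffffff));
    }
}
```

The recovered y should match line 4 of the listing; its first bit belongs to the old state[i], and the rest to state[i + 1].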

However, as we work backwards through the state, we will already have calculated state[i + 1], so determining its last 31 bits is not useful to us. What we really want is to determine the last 31 bits of state[i]. To do this we can apply the same transformations listed above to state[i - 1]. This will give us the last 31 bits of state[i].

Putting it all together, our algorithm looks like this:

for (int i = 623; i >= 0; i--) {
  int result = 0;
  // first we calculate the first bit
  int tmp = state[i];
  tmp ^= state[(i + 397) % 624];
  // if the first bit is set, the magic number was applied, so unapply it
  if ((tmp & 0x80000000) == 0x80000000) {
    tmp ^= 0x9908b0df;
  }
  // the second bit of tmp is the first bit of the result
  result = (tmp << 1) & 0x80000000;

  // work out the remaining 31 bits
  tmp = state[(i - 1 + 624) % 624];
  tmp ^= state[(i + 396) % 624];
  if ((tmp & 0x80000000) == 0x80000000) {
    tmp ^= 0x9908b0df;
    // since it was odd, the last bit must have been 1
    result |= 1;
  }
  // extract the middle 30 bits
  result |= (tmp << 1) & 0x7fffffff;
  state[i] = result;
}
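One way to gain some confidence in the loop above is a round trip check: run the forward algorithm from Part 3, reverse it, and compare. The sketch below does that; the class and method names are my own. Note one assumption: the last 31 bits of state[0] are recovered from state[623], so the state being recovered needs to have come out of a forward step itself. That is why the check applies the forward step twice before reversing once.

```java
import java.util.Arrays;
import java.util.Random;

final class ReverseTwistCheck {
    static final int N = 624;
    static final int MAGIC = 0x9908b0df;

    // Forward step from Part 3, applied in place.
    static void twist(int[] state) {
        for (int i = 0; i < N; i++) {
            int y = (state[i] & 0x80000000) | (state[(i + 1) % N] & 0x7fffffff);
            int next = (y >>> 1) ^ state[(i + 397) % N];
            if ((y & 1) == 1) {
                next ^= MAGIC;
            }
            state[i] = next;
        }
    }

    // The reversal loop from above, wrapped in a method.
    static void untwist(int[] state) {
        for (int i = N - 1; i >= 0; i--) {
            // first bit of the old state[i]
            int tmp = state[i] ^ state[(i + 397) % N];
            if ((tmp & 0x80000000) != 0) {
                tmp ^= MAGIC;
            }
            int result = (tmp << 1) & 0x80000000;
            // remaining 31 bits, taken from the step that produced state[i - 1]
            tmp = state[(i - 1 + N) % N] ^ state[(i + 396) % N];
            if ((tmp & 0x80000000) != 0) {
                tmp ^= MAGIC;
                result |= 1;
            }
            result |= (tmp << 1) & 0x7fffffff;
            state[i] = result;
        }
    }

    // Twists an arbitrary state twice, reverses the second twist,
    // and reports whether the state after the first twist came back.
    static boolean roundTrips(long seed) {
        Random random = new Random(seed);
        int[] state = new int[N];
        for (int i = 0; i < N; i++) {
            state[i] = random.nextInt();
        }
        twist(state);
        int[] expected = state.clone();
        twist(state);
        untwist(state);
        return Arrays.equals(expected, state);
    }

    public static void main(String[] args) {
        System.out.println("round trip ok: " + roundTrips(1L));
    }
}
```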

Dealing with non consecutive numbers

What happens if, while we are collecting 624 numbers from the application, some other web request comes in at the same time and obtains a number? Our 624 numbers won't be consecutive. Can we detect that, and what can we do about it? Detecting it is simple: having collected 624 numbers, we can predict the 625th, and if it doesn't match the next number the application gives us, then we know we've missed some.

Finding out how many numbers we've missed is the next task. Let's say we only missed the 624th number. This is fairly straightforward to detect: we would find that our state[623] is equal to what we were expecting to be the new state[0]. We would then know that we've missed one number, and by continuing to extract numbers from the application, and comparing that with the results we were expecting, we can narrow down which one.
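Here is a rough sketch of that detection step, reusing the algorithms from Part 3 in compact form. The Twister class is a toy stand-in for the application's generator that I've made up for the demo (its initial fill is arbitrary), and looksConsecutive and demo are likewise hypothetical names.

```java
import java.util.Random;

final class MissedNumberCheck {
    static final int N = 624;

    // State update from Part 3, applied in place.
    static void twist(int[] s) {
        for (int i = 0; i < N; i++) {
            int y = (s[i] & 0x80000000) | (s[(i + 1) % N] & 0x7fffffff);
            int next = (y >>> 1) ^ s[(i + 397) % N];
            if ((y & 1) == 1) next ^= 0x9908b0df;
            s[i] = next;
        }
    }

    // Output transform from Part 3.
    static int temper(int t) {
        t ^= t >>> 11;
        t ^= (t << 7) & 0x9d2c5680;
        t ^= (t << 15) & 0xefc60000;
        t ^= t >>> 18;
        return t;
    }

    static int unRight(int value, int shift) {
        int result = 0;
        for (int i = 0; i * shift < 32; i++) {
            int part = value & ((-1 << (32 - shift)) >>> (shift * i));
            value ^= part >>> shift;
            result |= part;
        }
        return result;
    }

    static int unLeft(int value, int shift, int mask) {
        int result = 0;
        for (int i = 0; i * shift < 32; i++) {
            int part = value & ((-1 >>> (32 - shift)) << (shift * i));
            value ^= (part << shift) & mask;
            result |= part;
        }
        return result;
    }

    static int untemper(int v) {
        v = unRight(v, 18);
        v = unLeft(v, 15, 0xefc60000);
        v = unLeft(v, 7, 0x9d2c5680);
        return unRight(v, 11);
    }

    // Toy generator for the demo: hands out transformed state words in order
    // and updates the state after every 624 outputs.
    static final class Twister {
        private final int[] state = new int[N];
        private int index = 0;

        Twister(long seed) {
            Random fill = new Random(seed);
            for (int i = 0; i < N; i++) state[i] = fill.nextInt();
        }

        int next() {
            if (index == N) {
                twist(state);
                index = 0;
            }
            return temper(state[index++]);
        }
    }

    // True if the number predicted from the 624 collected outputs matches the 625th.
    static boolean looksConsecutive(int[] collected, int next) {
        int[] state = new int[N];
        for (int i = 0; i < N; i++) state[i] = untemper(collected[i]);
        twist(state);
        return temper(state[0]) == next;
    }

    // Collects 624 outputs, losing the one due at position skipAt (-1 loses none),
    // then checks the collection against the following output.
    static boolean demo(int skipAt) {
        Twister app = new Twister(7L);
        for (int i = 0; i < 100; i++) app.next(); // start part-way through the state
        int[] collected = new int[N];
        for (int i = 0; i < N; i++) {
            if (i == skipAt) app.next(); // another request took this number
            collected[i] = app.next();
        }
        return looksConsecutive(collected, app.next());
    }

    public static void main(String[] args) {
        System.out.println("no gap:  " + demo(-1));
        System.out.println("one gap: " + demo(300));
    }
}
```

A mismatch tells us that something was missed, not where. A match by pure chance is possible in principle, but with 32 bit outputs it is unlikely enough to ignore.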

A generalised algorithm for doing this is beyond the scope of these blog posts. But it should be clear from the reverse engineering of the steps that if most of the values are correct, but only a few are missing, determining what they were will be a fairly simple process.

We now know how to determine the internal state and how to go backwards as well as forwards in two of the most popular PRNG algorithms. In Part 5, we will look at what developers can do to ensure that their applications are safe against PRNG attacks.

Cracking Random Number Generators - Part 3

In Part 1 and Part 2 of this series we focussed on one of the simplest PRNGs, the linear congruential PRNG. We looked in detail at Java's implementation, and then wrote algorithms to crack the seed, and to calculate previous seeds from the current seed.

Not many other languages use a linear congruential PRNG. Originally, the C and C++ implementations of rand() were very simple linear congruential PRNGs, but these days they use more complex algorithms to generate higher quality random numbers. PHP's rand() function delegates to the platform's rand() function, so depending on the platform, it may be a linear congruential PRNG.

A much more popular PRNG algorithm is the Mersenne Twister. This is used by Ruby's rand() function, Python's random module, and PHP's mt_rand() function.

The Mersenne Twister was invented in 1997. The most common implementation of it uses an internal state of 624 32 bit words (integers). These integers are handed out sequentially, applying a transform algorithm to each one before handing them out. Once all 624 integers have been handed out, an algorithm is applied to the state, to get the next 624 integers.

Generating the next state

The algorithm for generating the next state is simple; it makes heavy use of bitwise operations such as shifts, masks and xors. The algorithm is as follows:

int[] state = new int[624]; // holds the current internal state
// Iterate through the state
for (int i = 0; i < 624; i++) {
  // y is the first bit of the current number,
  // and the last 31 bits of the next number
  int y = (state[i] & 0x80000000) + (state[(i + 1) % 624] & 0x7fffffff);
  // first bitshift y by 1 to the right
  int next = y >>> 1;
  // xor it with the 397th next number
  next ^= state[(i + 397) % 624];
  // if y is odd, xor with magic number
  if ((y & 1L) == 1L) {
    next ^= 0x9908b0df;
  }
  // now we have the result
  state[i] = next;
}

Obtaining the next number

Before handing the next integer out, a transform is applied to each integer. This transform has 4 steps, each step involves xoring the number with a bit shifted and bit masked version of itself. The algorithm is as follows:

currentIndex++;
int tmp = state[currentIndex];
tmp ^= (tmp >>> 11);
tmp ^= (tmp << 7) & 0x9d2c5680;
tmp ^= (tmp << 15) & 0xefc60000;
tmp ^= (tmp >>> 18);
return tmp;

Cracking the Mersenne Twister

Since the Mersenne Twister contains 624 integers of internal state that it hands out sequentially, obtaining 624 consecutive integers from the generator is the first step in cracking it. In a common webapp, this might be done by issuing 624 requests, and recording the token. If we can then undo the transform applied to these integers, we will then have the complete internal state. Having calculated that, we can predict every subsequent number.

In order to undo the transform, we will take it one step at a time, in reverse order. The step of the transform we need to understand first is therefore:

tmp ^= (tmp >>> 18);

Let's look at what this looks like in binary:

10110111010111100111111001110010                    tmp
00000000000000000010110111010111100111111001110010  tmp >>> 18
10110111010111100101001110100101                    tmp ^ (tmp >>> 18)

As is clear from the first 18 bits of the lines above, we can easily work out what the first 18 bits of the previous number were: they are the same as the result. Taking those 18 bits, we can bitshift them across 18 bits and xor them with the value, which will give us the entire number. A generalised algorithm can be written to do this for any arbitrary number of bits, and we can reuse this to reverse the first step as well:

int unBitshiftRightXor(int value, int shift) {
  // the part of the value we are up to (with a width of shift bits)
  int i = 0;
  // we accumulate the result here
  int result = 0;
  // iterate until we've done the full 32 bits
  while (i * shift < 32) {
    // create a mask for this part
    int partMask = (-1 << (32 - shift)) >>> (shift * i);
    // obtain the part
    int part = value & partMask;
    // unapply the xor from the next part of the integer
    value ^= part >>> shift;
    // add the part to the result
    result |= part;
    i++;
  }
  return result;
}

This method can be used to reverse the two right bitshift operations. The left bitshift operations also include masking the bitshifted number with a magic number, so when we unapply part as above, we have to apply the mask to it, as follows:

int unBitshiftLeftXor(int value, int shift, int mask) {
  // the part of the value we are up to (with a width of shift bits)
  int i = 0;
  // we accumulate the result here
  int result = 0;
  // iterate until we've done the full 32 bits
  while (i * shift < 32) {
    // create a mask for this part
    int partMask = (-1 >>> (32 - shift)) << (shift * i);
    // obtain the part
    int part = value & partMask;
    // unapply the xor from the next part of the integer
    value ^= (part << shift) & mask;
    // add the part to the result
    result |= part;
    i++;
  }
  return result;
}

Using these two functions, we can unapply the transformation applied to the numbers coming out of the PRNG:

int value = output;
value = unBitshiftRightXor(value, 18);
value = unBitshiftLeftXor(value, 15, 0xefc60000);
value = unBitshiftLeftXor(value, 7, 0x9d2c5680);
value = unBitshiftRightXor(value, 11);
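
A quick way to check the reversal is to apply the output transform and then undo it, which should give back the value we started with. The sketch below does that with the two functions above, written a little more compactly; the temper and untemper names and the TemperRoundTrip class are my own additions.

```java
final class TemperRoundTrip {
    // The output transform from "Obtaining the next number".
    static int temper(int tmp) {
        tmp ^= (tmp >>> 11);
        tmp ^= (tmp << 7) & 0x9d2c5680;
        tmp ^= (tmp << 15) & 0xefc60000;
        tmp ^= (tmp >>> 18);
        return tmp;
    }

    static int unBitshiftRightXor(int value, int shift) {
        int result = 0;
        for (int i = 0; i * shift < 32; i++) {
            // mask for the part we are up to, starting from the left
            int partMask = (-1 << (32 - shift)) >>> (shift * i);
            int part = value & partMask;
            // unapply the xor from the next part of the integer
            value ^= part >>> shift;
            result |= part;
        }
        return result;
    }

    static int unBitshiftLeftXor(int value, int shift, int mask) {
        int result = 0;
        for (int i = 0; i * shift < 32; i++) {
            // mask for the part we are up to, starting from the right
            int partMask = (-1 >>> (32 - shift)) << (shift * i);
            int part = value & partMask;
            // unapply the masked xor from the next part of the integer
            value ^= (part << shift) & mask;
            result |= part;
        }
        return result;
    }

    // Undoes temper by reversing its four steps in reverse order.
    static int untemper(int output) {
        int value = output;
        value = unBitshiftRightXor(value, 18);
        value = unBitshiftLeftXor(value, 15, 0xefc60000);
        value = unBitshiftLeftXor(value, 7, 0x9d2c5680);
        value = unBitshiftRightXor(value, 11);
        return value;
    }

    public static void main(String[] args) {
        int[] samples = {0, 1, -1, 0x12345678, Integer.MIN_VALUE, Integer.MAX_VALUE};
        for (int sample : samples) {
            int recovered = untemper(temper(sample));
            System.out.println(Integer.toHexString(sample) + " -> " + Integer.toHexString(recovered));
        }
    }
}
```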

Having applied the above algorithm to 624 consecutive numbers from the PRNG, we now have the complete state of the Mersenne Twister, and can easily determine every subsequent value. In Part 4 of the series, we'll investigate how to find previous values generated by the Mersenne Twister, and also discuss how to deal with being unable to get consecutive numbers out of the PRNG.

Cracking Random Number Generators - Part 2

In Part 1 of this series, we saw how simple it is to predict future values generated by a linear congruential PRNG. In this part, we will look at how to calculate past values generated by a linear congruential PRNG.

Undoing three simple operations

As we saw before, going forward on a pseudo random number generator involves multiplying by the multiplier, adding the addend, and then trimming the result back down to 48 bits. Surely this can be solved by a simple subtraction and divide?

It's not that simple. Let's have a look at it in binary:

Seed 1:     110011111101000111101001010011010100010001001111
Multiplier:              10111011110111011001110011001101101 *
------------------------------------------------------------
             11111011111111000111100000110010000111110100011
Addend:                                                 1011 +
------------------------------------------------------------
Seed 2:      11111011111111000111100000110010000111110101110
Addend:                                                 1011 -
------------------------------------------------------------
             11111011111111000111100000110010000111110100011
Multiplier:              10111011110111011001110011001101101 /
------------------------------------------------------------
                                               1010101110110

As you can see, we've done the multiply, the add, then subtracted the addend, and divided by the multiplier, and the result is nothing like the original seed. Why not? The problem is the bitmask we applied. When we applied the multiplier originally, it overflowed beyond 48 bits. The most significant bits were discarded, meaning that a simple divide cannot get us the result. So how can we find the result?

The solution comes in going back to primary school. In primary school we learnt how to do long multiplication. Let's look at what multiplying two small binary numbers looks like, using long multiplication:

     101101
     001011 *
-----------
     101101
    101101
   000000
  101101
 000000
000000
-----------
00111101111 
     111111 &
-----------
     101111

Let's say that the first number is our multiplier, it is the known, and the second number is the unknown that we want to determine. The final number is the seed that we have, trimmed down to the number of bits, it is known, but the excess bits are not.

It should be clear from the above that there is a pattern: for each 1 in the second number, there is a corresponding copy of the first number in the lines below. We can also see that if the last digit of the original number was 0, then the last digit of the new number will be 0, and if it was 1, then the last digit of the new number will be 1. In the above example we can therefore determine that the last digit of the original number was 1. Because we know that, we can subtract the multiplier from our seed, reversing the effect of the last digit of the original number. Here is the result:

Result: -----1

     101101
     00101- *
-----------
    101101
   000000
  101101
 000000
000000
-----------
00111000010 
     111111 &
-----------
     000010

We can then continue the process with the next digit: if the next digit from the end in the seed is 1, then the next digit in the result is 1, otherwise it's 0. If it is 1, we subtract the multiplier, this time bitshifted across by 1:

Result: ----11

     101101
     0010-- *
-----------
   000000
  101101
 000000
000000
-----------
00101101000 
     111111 &
-----------
     101000

Continuing this process, of scanning from right to left, taking the next digit in the seed, and subtracting the effect that that digit would have had on the seed, we can determine the end result.

This only works if our multiplier is odd, which it is. If it weren't, we would only have to bitshift the seed to the left by the number of 0's on the end of the multiplier, and every bit that we shifted across would be one less bit of the final result that we could determine. So our algorithm for determining the previous seed looks like this:

long seed = currentSeed;
// reverse the addend from the seed
seed -= addend;
long result = 0;
// iterate through the seeds bits
for (int i = 0; i < 48; i++)
{
    long mask = 1L << i;
    // find the next bit
    long bit = seed & mask;
    // add it to the result
    result |= bit;
    if (bit == mask)
    {
        // if the bit was 1, subtract its effects from the seed
        seed -= multiplier << i;
    }
}
System.out.println("Previous seed: " + result);

And there we have it, we can now predict all previous values from java.util.Random given the current seed.
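
As a sanity check, here is a sketch that steps a seed forward and then recovers it again. The multiplier and addend are the constants java.util.Random uses; the class name, the method names and the bits parameter are my own, added so that the same code can also replay the six bit long multiplication example from earlier.

```java
final class PreviousSeed {
    // Constants used by java.util.Random's linear congruential generator.
    static final long MULTIPLIER = 0x5DEECE66DL;
    static final long ADDEND = 0xBL;

    // One forward step, trimmed back down to the given number of bits.
    static long next(long seed, long multiplier, long addend, int bits) {
        return (seed * multiplier + addend) & ((1L << bits) - 1);
    }

    // The bit by bit reversal from above. Only works for an odd multiplier.
    static long previous(long current, long multiplier, long addend, int bits) {
        long seed = current - addend;
        long result = 0;
        for (int i = 0; i < bits; i++) {
            long mask = 1L << i;
            long bit = seed & mask;
            result |= bit;
            if (bit == mask) {
                // the bit was 1, so subtract its effect from the seed
                seed -= multiplier << i;
            }
        }
        return result;
    }

    public static void main(String[] args) {
        // The six bit example: multiplier 101101, unknown 001011, no addend.
        long product = next(0b001011, 0b101101, 0, 6);
        System.out.println("six bit example: " + Long.toBinaryString(previous(product, 0b101101, 0, 6)));

        // A 48 bit seed, stepped forward once and then recovered.
        long seed = Long.parseLong("110011111101000111101001010011010100010001001111", 2);
        long stepped = next(seed, MULTIPLIER, ADDEND, 48);
        System.out.println("recovered: " + (previous(stepped, MULTIPLIER, ADDEND, 48) == seed));
    }
}
```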

That covers everything you need to know about hacking a linear congruential PRNG. However, most common platforms don't use such a simple PRNG as their default random number generator. In Part 3 we will look at how to hack a very popular PRNG, the Mersenne Twister.

About

Hi! My name is James Roper, and I am a software developer with a particular interest in open source development and trying new things. I program in Scala, Java, Go, PHP, Python and Javascript, and I work for Lightbend as the architect of Kalix. I also have a full life outside the world of IT, enjoy playing a variety of musical instruments and sports, and currently I live in Canberra.