A More Perfect Union

Learn Rust the Dangerous Way, Part 4

(Series Overview)

In part 3 we found that our use of uninitialized memory was a premature optimization that didn't actually improve performance. This left us with only one remaining unsafe function, but, boy, is it a doozy.

In this part, I'll begin the process of corralling its unsafe optimizations into more clearly safe code, by replacing arbitrary pointer casting with a lightweight abstraction.

Refresher on type punning in advance

advance is the hottest function in our program, called repeatedly to step the solar system simulation forward by tiny increments. Appropriately, it's the most heavily optimized code in the program. Part of the reason for its speed is that it uses SSE for most of its math, allowing it to perform roughly twice as many floating-point operations in a given period of time.

But advance is the only code that uses SSE. In particular, advance acts on data structures that are defined using simple f64 instead of vectors. To bridge that gap, advance arranges the f64s in memory so that they're aligned like 128-bit vectors, and then reinterprets them as a different type at basically no runtime cost, by casting the type of a pointer.

This is a powerful tool. It's also a fantastic way to shoot yourself in the foot if you're not careful! If the two types have different sizes, if the alignment is wrong, if padding appears in different places within the types, we've got undefined behavior in both C and Rust. Best case, that's a portability issue when we want to target something like ARM; worst case, it's an intermittent crash bug that hides a security exploit.

And so, doing this sort of thing in Rust is unsafe.

As of the end of part 3, we had this program:

It contains two cases of pointer casting. Both look similar; here's one:

position_Delta[m] =
    *(&position_Deltas[m].0 as *const f64 as *const __m128d).add(i);

This is loading up the __m128d array position_Delta with copies of pairs of f64 taken from position_Deltas. (The names are confusingly close; if this were my program, I would probably change one or the other.)

Let's pick apart this line.

  1. &position_Deltas[m].0 takes the address of position_Deltas[m].0, as a reference.

  2. as *const f64 converts this reference to a raw pointer. The result is now unsafe to dereference, but in exchange, we gain the ability to do dangerous things to it, like pointer arithmetic — or changing its type.

  3. as *const __m128d does the latter, by converting it to a pointer to vectors.

  4. *(...).add(i) performs C-style unchecked array indexing. Because the pointer is now a pointer to vectors, this loads a vector.
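
To make those four steps concrete, here is a minimal sketch that spells each one out on its own line. The four-element Align16 is a hypothetical stand-in for the wrapper from part 1, load_pair is a name invented for illustration, and the transmute at the end is only there so the result can be printed. It assumes an x86-64 target.

```rust
use std::arch::x86_64::__m128d;
use std::mem;

/// Hypothetical stand-in for the `Align16` wrapper from part 1.
#[repr(align(16))]
struct Align16([f64; 4]);

/// Loads the `i`th pair of `f64`s as one vector, using the same chain of
/// casts as the line above. (Illustrative helper, not from the program.)
fn load_pair(data: &Align16, i: usize) -> [f64; 2] {
    // Four f64s make two vectors, so only indices 0 and 1 are in bounds.
    assert!(i < 2);

    // Step 1: take the address of the array, as a reference.
    let reference: &[f64; 4] = &data.0;
    // Step 2: convert the reference to a raw pointer to the first f64.
    let scalars = reference as *const f64;
    // Step 3: change the pointee type to a 128-bit vector.
    let vectors = scalars as *const __m128d;
    // Step 4: C-style unchecked indexing, followed by a vector-sized load.
    // Safety: the array is 16-byte aligned, holds a whole number of
    // vectors, is initialized, and `i` was checked above.
    let v: __m128d = unsafe { *vectors.add(i) };

    // Safety: `__m128d` and `[f64; 2]` are both 16 bytes of plain f64 data.
    unsafe { mem::transmute::<__m128d, [f64; 2]>(v) }
}

fn main() {
    let data = Align16([1.0, 2.0, 3.0, 4.0]);
    println!("{:?}", load_pair(&data, 1));
}
```

Only step 4 needs an unsafe block; creating and casting raw pointers is allowed in safe code, and it's the dereference that carries the risk.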

Is it safe?

Here's an aside on when unsafe is safe. This is a topic we'll keep returning to for the rest of the series.

The unsafe code above can be safe, if we uphold the following rules:

  1. The array position_Deltas[m].0 must be aligned like an __m128d, so that we don't attempt an unaligned load of a vector. Accessing memory at the wrong alignment is undefined behavior in both C and Rust, and it can also hurt performance. We have ensured that the alignment is right using the Align16 struct from part 1; that struct is the reason why we need the trailing .0.

  2. The array must also be a whole number of __m128ds in length. In this case, because an __m128d consists of two f64s, that means the array length must be even. We've also ensured this, via the ROUNDED_INTERACTIONS_COUNT size. (The "rounded" here means "round up to the next even number.")

  3. i must not stray outside the bounds of the array of vectors, meaning it must stay between 0 and ROUNDED_INTERACTIONS_COUNT / 2 - 1, inclusive. The for loop surrounding the code ensures this.

  4. The in-memory representation of any two f64s must be a valid __m128d, too. We get this one for free, due to the definition of the types: __m128d is literally a pair of f64s.

  5. The contents of position_Deltas must be initialized memory. In the case of this code fragment, position_Deltas gets initialized just above, so this holds.

Like the equivalent code in C, this code is unsafe in that it does things that can potentially violate memory safety if used incorrectly, but it is also safe in that the code isn't currently violating memory safety because we're careful how we use it. Of course, one of the risks of unsafe is that a small change somewhere else in the program breaks the code's rules. For instance, if we mess up Align16's alignment property during maintenance, this code is now relying on undefined behavior; it might work on x86, possibly for quite a while, but is sensitive to how the linker decides to lay out memory, and will eventually break.
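
Some of these rules can be turned into compile-time checks, so that a careless edit breaks the build instead of quietly introducing undefined behavior. The following is a sketch and not part of the program: the constant values and this definition of Align16 are assumptions made to keep the example self-contained, vector_count is a made-up helper, and assert! inside a const needs Rust 1.57 or later.

```rust
use std::arch::x86_64::__m128d;
use std::mem::align_of;

// Assumed value: 5 bodies give 10 pairwise interactions.
const INTERACTIONS_COUNT: usize = 10;
// Round up to the next even number.
const ROUNDED_INTERACTIONS_COUNT: usize =
    INTERACTIONS_COUNT + INTERACTIONS_COUNT % 2;

/// Assumed shape of the `Align16` wrapper from part 1.
#[repr(align(16))]
struct Align16([f64; ROUNDED_INTERACTIONS_COUNT]);

// Rule 1: the storage must be at least as aligned as a vector.
const _: () = assert!(align_of::<Align16>() >= align_of::<__m128d>());

// Rule 2: the array must hold a whole number of two-element vectors.
const _: () = assert!(ROUNDED_INTERACTIONS_COUNT % 2 == 0);

/// Number of `__m128d` values that fit in the array (hypothetical helper).
fn vector_count() -> usize {
    ROUNDED_INTERACTIONS_COUNT / 2
}

fn main() {
    println!("{} vectors", vector_count());
}
```

If someone later removes the repr(align(16)) attribute or makes the count odd, the corresponding const assertion fails at compile time. Rules 3 through 5 depend on runtime behavior and can't be checked this way.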

Type punning via union

The fundamental thing we're doing here — treating memory first as one type, then as another, with no conversions — is unsafe, so we won't be able to do this task in 100% safe code. Instead, let's think about how we can make the unsafe code more compact and robust.

When we need to interpret some memory as either one type or another, depending on context, what tools do we have? We can cast pointers, yes, but we can also use unions.

A union in Rust is basically equivalent to a union in C: it looks like a struct, but all of its fields exist simultaneously, superimposed over one another. For instance, if we wanted some bits in memory that we sometimes interpret as a u32 and sometimes as an f32, we would write:

union IntOrFloat {
    i: u32,
    f: f32,
}
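
As a quick sketch of how such a union behaves (the helper name float_bits is invented for this example, and I've added #[repr(C)] so the field offsets are pinned down): initializing one field and then reading the other reinterprets the same four bytes.

```rust
#[repr(C)]
union IntOrFloat {
    i: u32,
    f: f32,
}

/// Returns the raw bits of `x` by writing one field and reading the other.
fn float_bits(x: f32) -> u32 {
    let value = IntOrFloat { f: x };
    // Safety: both fields are 4 bytes, and every 32-bit pattern is a
    // valid u32.
    unsafe { value.i }
}

fn main() {
    println!("{:#010x}", float_bits(1.0));
    // The standard library offers the same conversion as a safe method.
    assert_eq!(float_bits(1.0), 1.0f32.to_bits());
}
```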

Since our program uses fixed-length arrays of f64 that are sometimes interpreted as arrays of __m128d, we could use a union like this:

#[repr(C)]
union Interactions {
    scalars: [f64; ROUNDED_INTERACTIONS_COUNT],
    vectors: [__m128d; ROUNDED_INTERACTIONS_COUNT / 2],
}

Accessing vectors[i] reads the same memory as the combination of scalars[2 * i] and scalars[2 * i + 1].

This has a subtle advantage over what we did before: a union inherits the alignment of its most strictly aligned member. So our Interactions type is automatically aligned like an __m128d, and we don't need a separate #[repr(align(16))] attribute like we did before! (I did slap on a #[repr(C)], to specify that I want the type laid out literally. It may well make no difference in current versions of Rust, but without it the language doesn't promise that every field starts at offset 0, so it's the right thing to do.)
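
Both claims, the overlap and the inherited alignment, are easy to check. Here is a small sketch that does so; the value 10 for ROUNDED_INTERACTIONS_COUNT and the helper pair_via_vectors are assumptions for the sake of a self-contained example, and it assumes an x86-64 target.

```rust
use std::arch::x86_64::__m128d;
use std::mem::{align_of, size_of, transmute};

// Assumed value, to keep the example self-contained.
const ROUNDED_INTERACTIONS_COUNT: usize = 10;

#[repr(C)]
union Interactions {
    scalars: [f64; ROUNDED_INTERACTIONS_COUNT],
    vectors: [__m128d; ROUNDED_INTERACTIONS_COUNT / 2],
}

/// Stores 0.0, 1.0, 2.0, ... through `scalars`, then reads pair `i` back
/// through `vectors`. (Hypothetical helper for illustration.)
fn pair_via_vectors(i: usize) -> [f64; 2] {
    let mut scalars = [0.0; ROUNDED_INTERACTIONS_COUNT];
    for (k, s) in scalars.iter_mut().enumerate() {
        *s = k as f64;
    }
    let storage = Interactions { scalars };
    // Safety: any 16 bytes of initialized f64 data form a valid __m128d.
    let v: __m128d = unsafe { storage.vectors[i] };
    // Safety: `__m128d` and `[f64; 2]` are both 16 bytes of plain f64 data.
    unsafe { transmute::<__m128d, [f64; 2]>(v) }
}

fn main() {
    println!("align = {}", align_of::<Interactions>());
    println!("size  = {}", size_of::<Interactions>());
    println!("pair 3 = {:?}", pair_via_vectors(3));
}
```

Pair 3 comes back as scalars 6 and 7, and the union reports 16-byte alignment without any explicit align attribute.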

If we replaced our old Align16 stuff with this union, we could rewrite the line from above as follows:

position_Delta[m] = position_Deltas[m].vectors[i];

Because we no longer have to go through raw pointers, we have the option of using checked square-bracket indexing. In this case, i is clearly in bounds because of how the loop is structured, so the compiler can probably optimize the runtime bounds checks away. (I'd try it and see, personally, and that's what we'll do here. We'll measure the effect later.)

Despite using checked indexing, this line is still unsafe, and the reason is subtle. In Rust, reading a field of a union is always unsafe.

Is this overkill? Technically, yes. You could imagine that Rust might have an exception for types where all bit patterns are valid (like f64), as opposed to types where random bits could produce dangerous values (like a pointer type). Currently, Rust is conservative on this: all reads of union fields are unsafe, and so is taking a reference to one. (Plain writes to a field are safe.)

So why is it a good thing to replace one bit of unsafe code with another? It comes down to...

How actual Rust programs use unsafe

In practice, most Rust programs don't directly use unsafe, but they do use it, hidden in safe abstractions like library functions. You can use Vec from the standard library without writing any unsafe code, for example, but you cannot implement Vec without unsafe code [1].

[1] In safe languages without unsafe, like Java or Go, reading through the standard library will eventually reveal magical types that exist in the language, but can't be expressed in the language (Java arrays, Go collections). There are very few of these in C (the standard library really is written in C) and likewise in Rust. I'm a big fan of this philosophy in both languages.

When applying unsafe in Rust, our goal is to produce abstractions like Vec: we get the dangerous bits right under the hood, in a way that the caller/user can't mess up in safe code.

Let's provide a safe API to our union.

As you probably noticed in the Rust book, we add method-style functions to types using impl blocks. While the book usually uses them for structs, they work just fine for unions too. We can use an impl block to add safe accessors for the otherwise-unsafe union members.

// Our union type, as seen above, reproduced here for your reference
#[repr(C)]
union Interactions {
    scalars: [f64; ROUNDED_INTERACTIONS_COUNT],
    vectors: [__m128d; ROUNDED_INTERACTIONS_COUNT / 2],
}

impl Interactions {
    /// Returns a reference to the storage as `f64`s.
    pub fn as_scalars(&mut self) -> &mut [f64; ROUNDED_INTERACTIONS_COUNT] {
        // Safety: the in-memory representations of `f64` and `__m128d` are
        // compatible, so accessing the union members is safe in any
        // order.
        unsafe {
            &mut self.scalars
        }
    }

    /// Returns a reference to the storage as `__m128d`s.
    pub fn as_vectors(&mut self)
        -> &mut [__m128d; ROUNDED_INTERACTIONS_COUNT/2]
    {
        // Safety: the in-memory representations of `f64` and `__m128d` are
        // compatible, so accessing the union members is safe in any
        // order.
        unsafe {
            &mut self.vectors
        }
    }
}

"Accessors" are a common thing to see in OO languages, but they're less common in Rust [2]. In all languages, using an accessor instead of directly poking at the innards of a type is a way of hiding or abstracting something, like the way data is actually stored. In Rust, accessors also hide something else: unsafe code.

[2] When I'm teaching Rust to people with OO backgrounds, I actively discourage them from writing accessor methods (getters, setters) for structs. Languages are different, and what's "good technique" in Java can tie you in knots in Rust, because of how accessors interact with Rust's borrow checker. In your case, your C experience will better equip you for success here.

The impl block above is doing something that's common in very low level Rust code: it's providing a safe API to an unsafe operation. The details vary, but the overall pattern always goes something like this:

  1. We've got a chunk of unsafe code, which we've convinced ourselves can be used safely in some circumstances.

  2. We add a comment, by convention [3] starting with something like Safety:, explaining what those circumstances are. Basically, we're explaining to our future selves why, today, we thought this was going to be safe.

  3. We write a wrapper function that ensures that those requirements are met, and do not mark it unsafe, so that safe code can call it.

  4. We expose the wrapper function to code that can't directly muck with the bits we're obscuring — here, it's pub, while the union members are not, so code outside this file has to use the accessors instead of directly accessing the members.

[3] This documentation convention for unsafe blocks is universal enough in the Rust community that the linter, clippy, has an opt-in lint (undocumented_unsafe_blocks) that checks for it.

In this case, meeting the requirements of the unsafe blocks is trivial: because the union maintains the required alignment, and the arrays are the same size (in bytes, not elements), accessing either member through a reference is always safe. If either of those things were not true, we might need to write some checking code or some asserts.
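
When the requirements aren't guaranteed by the types alone, the wrapper has to check them at runtime before reaching for unsafe. Here is a minimal sketch of that situation, using a made-up EvenSlice type (not part of the n-body program) that wraps the standard library's unchecked slice indexing:

```rust
mod pairs {
    /// A slice known to contain an even number of elements.
    /// The field is private, so code outside this module cannot build an
    /// `EvenSlice` without going through `new`.
    pub struct EvenSlice<'a>(&'a [f64]);

    impl<'a> EvenSlice<'a> {
        /// Checks the length requirement once, at construction.
        pub fn new(data: &'a [f64]) -> Option<Self> {
            if data.len() % 2 == 0 {
                Some(EvenSlice(data))
            } else {
                None
            }
        }

        /// Returns the `i`th pair, or `None` if `i` is out of range.
        pub fn pair(&self, i: usize) -> Option<(f64, f64)> {
            if i >= self.0.len() / 2 {
                return None;
            }
            // Safety: the check above guarantees 2 * i + 1 < len, so both
            // indices are in bounds.
            unsafe {
                Some((
                    *self.0.get_unchecked(2 * i),
                    *self.0.get_unchecked(2 * i + 1),
                ))
            }
        }
    }
}

fn main() {
    let data = [1.0, 2.0, 3.0, 4.0];
    let even = pairs::EvenSlice::new(&data).expect("even length");
    println!("{:?}", even.pair(1));
    println!("{:?}", even.pair(2));
}
```

The shape is the same as our accessors: an unsafe block, a Safety comment, a safe function around it, and privacy to keep callers from bypassing the check. The only difference is that here the wrapper does real work to earn the word "safe."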

And how many places in the code do we need to look to convince ourselves that those properties are true? More than one place: we need to read the union, we need to check the values of the constants (in particular, ROUNDED_INTERACTIONS_COUNT had better be even), we need to think about how __m128d is laid out in memory. We still can't rely entirely on local reasoning, though at least we've gathered all the related bits in one place.

But all the rest of the code, using the safe API, can reason locally. There is no situation in which as_scalars() is dangerous to safe code [4]. When we're writing or reading code elsewhere in the file, we can skim past as_scalars() or as_vectors() without stopping to think. And that's nice.

This is typical of a safe API around unsafe code in Rust. You typically can't use local reasoning on unsafe code except in trivial cases — you have to reason at the module/file level — but the safe wrappers ensure that you can use local reasoning on the rest of the code.

This case is simple enough (just a two-field union) that writing safe wrappers might feel like boilerplate, and in a way, it is. But as the unsafe operations you're wrapping grow more complex, safe wrappers become more important.

Now, a caveat. Because the union is defined in the same file as advance, putting pub on the accessors doesn't actually do anything: any code in the same file can access the union members directly. If this were truly idiomatic Rust code [5], we would separate the union into a separate module, which would require the rest of our code to only use the pub features, i.e. the accessors. For the purposes of this tutorial, I'm keeping everything in one file.

[4] as_scalars() can technically be dangerous in safe code if, for example, it has been handed bogus unaligned pointers in the guise of references. But only unsafe code can create such bogus references. In general, when someone says "safe Rust code can't do X," they mean "it can't do X unless some buggy unsafe code forces it to."

[5] Pedantic note: if this were really idiomatic Rust code, there would be both shared & and exclusive &mut versions of the accessors. If that sentence didn't mean anything to you, don't worry, you'll learn it later.
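
For the curious, here is a rough sketch of what the more idiomatic arrangement could look like: the union lives in its own module, and each view gets a shared and an exclusive accessor. The module name storage, the new constructor, the _mut suffixes, and the value of the constant are all my assumptions for this example; the program in this series doesn't do any of this.

```rust
mod storage {
    use std::arch::x86_64::__m128d;

    // Assumed value, to keep the example self-contained.
    pub const ROUNDED_INTERACTIONS_COUNT: usize = 10;

    #[repr(C)]
    pub union Interactions {
        // Private fields: code outside this module must use the accessors.
        scalars: [f64; ROUNDED_INTERACTIONS_COUNT],
        vectors: [__m128d; ROUNDED_INTERACTIONS_COUNT / 2],
    }

    impl Interactions {
        /// Creates zero-initialized storage.
        pub fn new() -> Self {
            Interactions { scalars: [0.0; ROUNDED_INTERACTIONS_COUNT] }
        }

        // Safety (all four accessors): the two fields have the same size
        // and compatible representations, and the storage is always
        // initialized by `new`.
        pub fn as_scalars(&self) -> &[f64; ROUNDED_INTERACTIONS_COUNT] {
            unsafe { &self.scalars }
        }
        pub fn as_scalars_mut(&mut self) -> &mut [f64; ROUNDED_INTERACTIONS_COUNT] {
            unsafe { &mut self.scalars }
        }
        pub fn as_vectors(&self) -> &[__m128d; ROUNDED_INTERACTIONS_COUNT / 2] {
            unsafe { &self.vectors }
        }
        pub fn as_vectors_mut(&mut self) -> &mut [__m128d; ROUNDED_INTERACTIONS_COUNT / 2] {
            unsafe { &mut self.vectors }
        }
    }
}

fn main() {
    let mut deltas = storage::Interactions::new();
    deltas.as_scalars_mut()[3] = 1.5;
    // Writing `deltas.scalars[3]` here would not compile: the field is
    // private to the `storage` module.
    println!("{}", deltas.as_scalars()[3]);
    println!("{} vectors", deltas.as_vectors().len());
}
```

With this layout, the set of code we have to audit for the Safety comment shrinks to the contents of one small module.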

Using the safe API

Our accessors return array references, which we first saw in part 2. Initializing our array as f64s now looks like this (note the parentheses in as_scalars()):

position_Deltas[m].as_scalars()[k] =
    bodies[i].position[m] - bodies[j].position[m];

And using its contents as vectors now looks like this:

position_Delta[m] = position_Deltas[m].as_vectors()[i];
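
Putting the two together, here is a small self-contained sketch of the round trip: fill the storage through as_scalars(), then read it back through as_vectors(). The union and accessors are the ones from above; the constant's value of 4, the roundtrip helper, and the final transmute (used only to get printable numbers out of the vector) are assumptions for illustration, as is the x86-64 target.

```rust
use std::arch::x86_64::__m128d;
use std::mem;

// Small assumed value, to keep the example short.
const ROUNDED_INTERACTIONS_COUNT: usize = 4;

#[repr(C)]
union Interactions {
    scalars: [f64; ROUNDED_INTERACTIONS_COUNT],
    vectors: [__m128d; ROUNDED_INTERACTIONS_COUNT / 2],
}

impl Interactions {
    /// Returns a reference to the storage as `f64`s.
    pub fn as_scalars(&mut self) -> &mut [f64; ROUNDED_INTERACTIONS_COUNT] {
        // Safety: the representations of `f64` and `__m128d` are compatible.
        unsafe { &mut self.scalars }
    }

    /// Returns a reference to the storage as `__m128d`s.
    pub fn as_vectors(&mut self) -> &mut [__m128d; ROUNDED_INTERACTIONS_COUNT / 2] {
        // Safety: the representations of `f64` and `__m128d` are compatible.
        unsafe { &mut self.vectors }
    }
}

/// Writes 0.0, 1.0, 2.0, ... as scalars, then reads vector `i` back.
/// (Hypothetical helper for illustration.)
fn roundtrip(i: usize) -> [f64; 2] {
    let mut deltas = Interactions { scalars: [0.0; ROUNDED_INTERACTIONS_COUNT] };
    for k in 0..ROUNDED_INTERACTIONS_COUNT {
        deltas.as_scalars()[k] = k as f64;
    }
    // No unsafe needed here: the accessor is a safe function, and the
    // square brackets are bounds-checked.
    let v: __m128d = deltas.as_vectors()[i];
    // Safety: `__m128d` and `[f64; 2]` are both 16 bytes of plain f64 data.
    unsafe { mem::transmute::<__m128d, [f64; 2]>(v) }
}

fn main() {
    println!("{:?}", roundtrip(1));
}
```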

Evaluation

Replacing all pointer-casting with uses of the union and its accessors gives this program:

It compiles to exactly the same size as the program that used pointer casting:

$ size nbody-3 nbody-4
   text	   data	    bss	    dec	    hex	filename
 265640	  10332	   7432	 283404	  4530c	nbody-3
 265640	  10332	   7432	 283404	  4530c	nbody-4

(There are some differences in the machine code, but they're not material.)

And the performance is the same:

Command                          Mean [s]         Min [s]   Max [s]   Ratio
./nbody.clang-8.bench 50000000   5.277 ± 0.007    5.258     5.282     1.00x
./nbody-1.bench 50000000         5.123 ± 0.024    5.095     5.161     0.97x
./nbody-2.bench 50000000         5.101 ± 0.005    5.093     5.107     0.97x
./nbody-3.bench 50000000         5.103 ± 0.002    5.100     5.105     0.97x
./nbody-4.bench 50000000         5.104 ± 0.002    5.101     5.107     0.97x

Which makes sense, if you think about it — the union is describing the same thing as the pointer casting code, only in a way that let us centralize the unsafe bits — but it's nice to see it work out that way in practice.

Let me be very clear about something: This change would also work just fine in C, and is in fact how I would have written the C code in the first place. Unions are a more specific and explicit way of treating memory as two different types, and are much harder to mess up than arbitrary pointer arithmetic and casting. Rust further nudges us toward the union approach by making it easier to type and wrap in a safe API.

Next steps

We've dealt with the most obvious unsafe code in the advance function, and centralized its unsafe bits into a difficult-to-abuse set of wrappers. There are still two different flavors of unsafe happening in the function, however, and both are kind of subtle.

In part 5, we'll fix one and fence in the other, producing a safe version of advance. Stay tuned!