Learning about the Foreign Function & Memory API with a Neural Processing Unit
We're learning about the Foreign Function & Memory API by multiplying matrices on a Neural Processing Unit.
Introduction
Do you know the Foreign Function & Memory (FFM) API? It’s a part of the Java platform that lets developers call native functions and directly manipulate memory outside the Java heap. It has been in development over the past few JDK releases, and is now a finalized feature in JDK 22! We’ve been looking at this latest iteration, and at what can be done with it.
Moreover, the tool jextract generates Java sources from a C header, thus providing a basic way to call the functions provided by a native library.
Since we’re working on embedded development at this time, we took an example that is specific to embedded development.
We’re multiplying matrices on a Neural Processing Unit (a.k.a., NPU, basically a specialized processor).
We’ll compare its performance with that of similar C++ code.
As the JDK Enhancement Proposal for FFM is quite detailed, we’ll try to give you a simpler overview and examples.
What is the Foreign Function & Memory (FFM) API, and how do I use it?
/!\ As the API has been changing over the last few iterations, be sure to check the documentation for the version you plan to use. This post targets the JDK 22 iteration!
Very basically, FFM aims to replace JNI, with a particular focus on safety, performance and generality. It provides a new way to work with memory outside the Java heap and with native libraries. Relying on native libraries is quite fashionable at the moment, particularly in the AI world.
5-minute tutorial
Since FFM has a strong focus on native library linking, let’s link some libraries! Note that the following are full examples. They compile and run with a JDK 22.
Find some native code
This small program simply looks up a function, whose name you pass as an argument, in commonly used libraries.
To do that, it first creates a linker, which then gives access to commonly used libraries via the defaultLookup() method.
Then we can search for our native function and print its address.
import java.lang.foreign.Linker;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SymbolLookup;
public class FindStdLibFunction {
    public static void main(String[] args) {
        Linker linker = Linker.nativeLinker();
        // Lookup covering the commonly used libraries of the platform.
        SymbolLookup stdlib = linker.defaultLookup();
        // find() returns an Optional; orElseThrow() unwraps it or fails if the symbol is missing.
        MemorySegment stdlibFunction = stdlib.find(args[0]).orElseThrow();
        System.out.println(stdlibFunction.toString());
    }
}
The important point to note here is that native functions are looked up through a SymbolLookup, which we obtain here from an instance of the Linker class.
Execute some native code
More interestingly, the following example actually executes a function implemented by commonly used libraries: abs.
This code starts by creating a Java object that knows how to execute the native function.
To create the MethodHandle instance, it needs:
- The address of the function that we found in the previous code snippet, and
- The method type signature, that we can find in a C header or the documentation.
Afterward, we need a way to pass our arguments to the native function.
For that purpose, we use an arena.
An arena manages the lifetime of memory blocks that live outside the Java heap, and whose contents you can modify.
Inside the arena, we can allocate some memory blocks, each represented by a MemorySegment.
A MemorySegment holds the address of the memory block it represents, along with the size of that block when it is known.
Here, we allocate memory space for two 32-bit signed integers.
Then, we set their values with the method setAtIndex.
Finally, once our function and arguments are ready, we call invokeExact to execute the native function and retrieve its return value.
import java.lang.foreign.Arena;
import java.lang.foreign.FunctionDescriptor;
import java.lang.foreign.Linker;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.lang.invoke.MethodHandle;
public class AbsTestRun {
    public static void main(String[] args) {
        Linker linker = Linker.nativeLinker();
        // Bind the C function `int abs(int)` to a Java method handle.
        MethodHandle absMethodHandle =
            linker.downcallHandle(
                linker.defaultLookup().find("abs").orElseThrow(),
                FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.JAVA_INT));
        try (Arena arena = Arena.ofConfined()) {
            // Off-heap space for two 32-bit integers, freed when the arena is closed.
            MemorySegment ptr = arena.allocate(ValueLayout.JAVA_INT, 2);
            ptr.setAtIndex(ValueLayout.JAVA_INT, 0, 10);
            ptr.setAtIndex(ValueLayout.JAVA_INT, 1, -10);
            // invokeExact requires the exact argument and return types: (int) -> int.
            int positive = (int) absMethodHandle.invokeExact(ptr.getAtIndex(ValueLayout.JAVA_INT, 0));
            int negative = (int) absMethodHandle.invokeExact(ptr.getAtIndex(ValueLayout.JAVA_INT, 1));
            System.out.println("abs of negative = " + negative + " positive = " + positive);
        } catch (Throwable e) {
            e.printStackTrace();
        }
    }
}
The most important thing to remember here is that we have to specify the type for every operation related to memory accesses, using the layouts defined in ValueLayout.
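The abs example only passes plain integers, so the native function never sees the memory allocated in the arena. As a complementary sketch (not part of the original examples), the following program passes an off-heap block to the C function strlen. It assumes JDK 22 and a 64-bit platform where size_t is 8 bytes; the class and method names are made up for illustration.

```java
import java.lang.foreign.Arena;
import java.lang.foreign.FunctionDescriptor;
import java.lang.foreign.Linker;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.lang.invoke.MethodHandle;

public class StrlenTestRun {
    // Calls the C function strlen on a Java string copied into native memory.
    static long nativeStrlen(String text) {
        Linker linker = Linker.nativeLinker();
        // size_t strlen(const char *s), modelled as (ADDRESS) -> JAVA_LONG.
        // Assumption: size_t is 8 bytes wide on the target platform.
        MethodHandle strlen = linker.downcallHandle(
                linker.defaultLookup().find("strlen").orElseThrow(),
                FunctionDescriptor.of(ValueLayout.JAVA_LONG, ValueLayout.ADDRESS));
        try (Arena arena = Arena.ofConfined()) {
            // allocateFrom copies the string into the arena as NUL-terminated UTF-8.
            MemorySegment cString = arena.allocateFrom(text);
            // The segment is passed as a pointer to the native function.
            return (long) strlen.invokeExact(cString);
        } catch (Throwable t) {
            throw new IllegalStateException("strlen call failed", t);
        }
    }

    public static void main(String[] args) {
        System.out.println("strlen = " + nativeStrlen("Hello FFM"));
    }
}
```

The memory backing the string is released as soon as the try block ends, which is why the native call has to happen inside it.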
A more tasteful example
We implemented a micro-benchmark of matrix multiplication on a Neural Processing Unit.
RKNN-Toolkit provides a way to use Rockchip NPUs on devices that have them.
The toolkit gives us the shared library (librknnrt.so
) as well as the C header file (rknn_matmul_api.h
) in order to communicate with the NPU.
To evaluate whether that would be useful for our embedded Java, we tested it!
To do that, we used jextract
to generate a bunch of Java boilerplate code, and we used it to write a Java program that uses the NPU.
On top of that, we developed code similar to a C++ micro-benchmark implemented by Martin Chang.
Data first, code later
roc_rk3588s_pc:/data/microdoc # ./jdk/bin/java --enable-native-access=ALL-UNNAMED -cp . org.rknn.JavaMicroBenchmarks
Size = 256 seconds = 0.00126035807
Size = 512 seconds = 0.0038540170400000002
Size = 1024 seconds = 0.0115888624
Size = 2048 seconds = 0.06323490967999999
Size = 4096 seconds = 0.84754899555
Size = 8192 seconds = 14.975589004349999
roc_rk3588s_pc:/data/microdoc # ./cppMicroBenchmarks
Size = 256 seconds = 0.00027
Size = 512 seconds = 0.00088
Size = 1024 seconds = 0.0044
Size = 2048 seconds = 0.03026
Size = 4096 seconds = 0.78822
Size = 8192 seconds = 14.7231
Experimental setup
We executed two micro-benchmarks:
- A C++ micro-benchmark, which is a slightly modified version of the one from this blog post;
- A Java micro-benchmark that is very similar to the C++ one, written from scratch with the help of Java sources generated by jextract.
We ran both micro-benchmarks on the rk3588s device through ADB.
The rk3588s runs a 64-bit Android operating system.
Those benchmarks both allocate two square matrices of some size provided as a parameter.
The timed region covers only the call to the framework function that multiplies the matrices, rknn_matmul_run.
We multiply those matrices together one hundred times, then report the average execution time.
The point of these benchmarks is to measure how the overhead of going through Java compares with the matrix multiplication time itself, as the matrix size grows.
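The measurement loop described above can be sketched as follows. This is a minimal illustration rather than the benchmark code itself: the AverageTimer class and its averageSeconds method are hypothetical names, and the timed task is a stand-in for the call to rknn_matmul_run.

```java
public class AverageTimer {
    // Runs the task `iterations` times and returns the mean duration in seconds.
    static double averageSeconds(Runnable task, int iterations) {
        if (iterations <= 0) {
            throw new IllegalArgumentException("iterations must be positive");
        }
        long totalNanos = 0;
        for (int i = 0; i < iterations; i++) {
            long start = System.nanoTime();
            task.run(); // only this call is timed, not the setup around it
            totalNanos += System.nanoTime() - start;
        }
        return totalNanos / 1e9 / iterations;
    }

    public static void main(String[] args) {
        // Placeholder workload; the real benchmark would call the NPU here.
        double[] sink = new double[1];
        double seconds = averageSeconds(() -> {
            for (int i = 0; i < 1_000_000; i++) {
                sink[0] += Math.sqrt(i);
            }
        }, 100);
        System.out.println("seconds = " + seconds);
    }
}
```

Timing each iteration separately keeps allocation and setup costs out of the reported average, which matches the intent of measuring only the multiplication call.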
So, how much does it cost me to use Java?
The following table compiles the previous results, and compares the Java results to a baseline, the C++ code.
Here’s the full Java code.
Here’s the full C++ code.
Matrix size | C++ (seconds) | Java (seconds) | Relative performance (C++ time / Java time) |
---|---|---|---|
256 | 0.00027 | 0.00126 | 21.43% |
512 | 0.00088 | 0.00385 | 22.86% |
1024 | 0.00441 | 0.01159 | 38.05% |
2048 | 0.03026 | 0.06324 | 47.85% |
4096 | 0.78822 | 0.84755 | 92.99% |
8192 | 14.7231 | 14.9756 | 98.31% |
From that table, we can clearly see that the bigger the matrix is, the larger the share of the time spent in the NPU.
On small matrix multiplications, there is a heavy runtime cost: the Java version takes roughly four to five times as long as the C++ one. But on bigger matrices, the difference is barely noticeable!
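For reference, the relative performance column can be reproduced by dividing the C++ time by the Java time. The small program below is only a sketch of that arithmetic, using the rounded timings from the table; the class and method names are made up for illustration.

```java
public class RelativePerformance {
    // Relative performance in percent: 100% means Java is as fast as C++.
    static double relativePercent(double cppSeconds, double javaSeconds) {
        return 100.0 * cppSeconds / javaSeconds;
    }

    public static void main(String[] args) {
        int[] sizes = {256, 512, 1024, 2048, 4096, 8192};
        // Timings in seconds, copied from the table above.
        double[] cpp = {0.00027, 0.00088, 0.00441, 0.03026, 0.78822, 14.7231};
        double[] java = {0.00126, 0.00385, 0.01159, 0.06324, 0.84755, 14.9756};
        for (int i = 0; i < sizes.length; i++) {
            System.out.printf("%d: %.2f%%%n", sizes[i], relativePercent(cpp[i], java[i]));
        }
    }
}
```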
Important FFM bits
We’ll now explain some details that were not obvious to us while we were writing this blog post and the accompanying examples and benchmarks.
Pointer operations
The data pointed to by a given MemorySegment can be read with the methods get and getAtIndex.
You must also provide a layout, which tells the FFM framework how many bytes it should read and how it should interpret them.
The method getAtIndex is useful when the data you manipulate is an array of elements of the same type, whereas get makes more sense in the case of structures.
In our example, we use getAtIndex multiple times.
There are uses of get in the code generated by jextract.
Similarly, use set and setAtIndex to write the bits representing your data.
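To make the difference between the two accessors concrete, here is a small sketch (with a hypothetical class name, not taken from the benchmark) that reads the same array element both ways: getAtIndex takes an element index, while get takes an offset in bytes.

```java
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;

public class GetVersusGetAtIndex {
    // Reads element 2 of an int array by index and by byte offset.
    static int[] readThirdElement() {
        try (Arena arena = Arena.ofConfined()) {
            MemorySegment ints = arena.allocate(ValueLayout.JAVA_INT, 4);
            for (int i = 0; i < 4; i++) {
                ints.setAtIndex(ValueLayout.JAVA_INT, i, i * 10);
            }
            // Index-based access: the layout size is used to scale the index.
            int byIndex = ints.getAtIndex(ValueLayout.JAVA_INT, 2);
            // Offset-based access: we compute the byte offset ourselves.
            long offset = 2L * ValueLayout.JAVA_INT.byteSize();
            int byOffset = ints.get(ValueLayout.JAVA_INT, offset);
            return new int[] {byIndex, byOffset};
        }
    }

    public static void main(String[] args) {
        int[] values = readThirdElement();
        System.out.println("by index = " + values[0] + ", by offset = " + values[1]);
    }
}
```

Both reads return the same value; byte offsets are the natural fit for structure fields, whose positions are not multiples of a single element size.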
There is a caveat that we want to address though, which we encountered while trying to check that our matrices were correct by printing them.
A MemorySegment checks that memory accesses are in bounds, given the size it contains.
When you create a new MemorySegment via Arena.allocate, it comes with the correct size.
The following code works:
// Allocating a new array of floats, and setting its first element to a random float.
MemorySegment ptr = arena.allocate(ValueLayout.JAVA_FLOAT, size);
ptr.setAtIndex(ValueLayout.JAVA_FLOAT, 0, (float) Math.random() * 10);
However, reading a pointer with MemorySegment.getAtIndex (see the following code) creates a MemorySegment with an incorrect size:
// Trying to access some row of the matrix, and then an element of that row.
MemorySegment row = matrix.getAtIndex(AddressLayout.ADDRESS, 0);
float v = row.getAtIndex(ValueLayout.JAVA_FLOAT, 0);
// An exception is thrown!
The exception is thrown because the MemorySegment object checks that memory accesses are in bounds.
When a pointer is freshly allocated, the MemorySegment knows its size.
However, when a MemorySegment object is created from an arbitrary pointer, the object doesn’t know the size of the memory block that the pointer points to.
In that case, the MemorySegment object holds a size of zero, which causes every memory access to fail bounds checks.
In order to correct the previous snippet, you can do the following:
// Trying to access some row of the matrix, and then an element of that row.
MemorySegment row = matrix.getAtIndex(AddressLayout.ADDRESS, 0).reinterpret(size * Float.BYTES);
float v = row.getAtIndex(ValueLayout.JAVA_FLOAT, 0);
MemorySegment.reinterpret sets the size of the memory block pointed to by the MemorySegment object, and thus how much memory you’re allowed to access from that MemorySegment.
Note that jextract works around this issue by providing MemorySegment objects whose sizes are Long.MAX_VALUE whenever a pointer is returned from a function call.
But using that size essentially disables bounds checks, which can cause unintended out-of-bounds accesses.
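The snippets above are fragments, so here is a self-contained sketch of the same situation. It builds a one-row "matrix" of pointers entirely in Java (an assumption made for illustration; in our benchmark the pointers come from native code), reads the row pointer back, and resizes it with reinterpret. Since reinterpret is a restricted method, running this may require the --enable-native-access option shown earlier.

```java
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;

public class ReinterpretExample {
    // Stores a float in a row, then reads it back through a pointer to that row.
    static float readThroughPointer() {
        int size = 4;
        try (Arena arena = Arena.ofConfined()) {
            MemorySegment row = arena.allocate(ValueLayout.JAVA_FLOAT, size);
            row.setAtIndex(ValueLayout.JAVA_FLOAT, 1, 42.0f);

            // A "matrix" holding a single pointer to the row.
            MemorySegment matrix = arena.allocate(ValueLayout.ADDRESS, 1);
            matrix.setAtIndex(ValueLayout.ADDRESS, 0, row);

            // Reading the pointer back yields a segment whose size is zero.
            MemorySegment raw = matrix.getAtIndex(ValueLayout.ADDRESS, 0);
            if (raw.byteSize() != 0) {
                throw new IllegalStateException("expected a zero-length segment");
            }

            // Give the segment its real size so that bounds checks can pass.
            MemorySegment sized = raw.reinterpret(size * ValueLayout.JAVA_FLOAT.byteSize());
            return sized.getAtIndex(ValueLayout.JAVA_FLOAT, 1);
        }
    }

    public static void main(String[] args) {
        System.out.println("value = " + readThroughPointer());
    }
}
```

Passing the exact size to reinterpret, rather than Long.MAX_VALUE, keeps the bounds checks meaningful for the resized segment.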
Layouts
The primitive types use layouts defined in ValueLayout.
For addresses, you should use an AddressLayout.
For more complex layouts such as structures, you might want to build a composite MemoryLayout instead, for example with MemoryLayout.structLayout.
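As an illustration of a structure layout, the sketch below describes a hypothetical C struct with two int fields (it is not a structure from the RKNN headers) and accesses one field through its byte offset. jextract generates comparable layout definitions for the structures it finds in a header.

```java
import java.lang.foreign.Arena;
import java.lang.foreign.MemoryLayout;
import java.lang.foreign.MemoryLayout.PathElement;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.StructLayout;
import java.lang.foreign.ValueLayout;

public class PointStructExample {
    // Layout of a hypothetical C struct: struct Point { int x; int y; };
    static final StructLayout POINT = MemoryLayout.structLayout(
            ValueLayout.JAVA_INT.withName("x"),
            ValueLayout.JAVA_INT.withName("y"));

    // Writes a value into the field y of a freshly allocated Point, then reads it back.
    static int roundTripY(int value) {
        // The layout computes the byte offset of the field for us.
        long yOffset = POINT.byteOffset(PathElement.groupElement("y"));
        try (Arena arena = Arena.ofConfined()) {
            MemorySegment point = arena.allocate(POINT);
            point.set(ValueLayout.JAVA_INT, yOffset, value);
            return point.get(ValueLayout.JAVA_INT, yOffset);
        }
    }

    public static void main(String[] args) {
        System.out.println("size = " + POINT.byteSize() + " bytes, y = " + roundTripY(7));
    }
}
```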
Conclusion
In this blog post, you saw how to use the new Foreign Function & Memory API in JDK 22. You read some very basic examples, as well as a real-world example of using FFM to drive an NPU. Finally, we described details that were not obvious to us.