
Simple Analysis Pass for Rust MIR

Rust's Mid-level Intermediate Representation is known as MIR. It sits between the HIR (the high-level IR, which is roughly an abstract syntax tree) and LLVM IR (the "low-level" IR). It is a greatly simplified form of Rust that is used for code generation, optimization, and some flow-sensitive safety checks, most notably the borrow checker. After borrow checking, rustc runs a number of optimization passes over the MIR.

These optimization passes are usually written in-tree, in the main Rust compiler codebase, and registered in the list of passes run by the run_optimization_passes() function.

We do numerous optimizations on the MIR because it is still generic (not yet monomorphized), so an optimization applied once benefits every instantiation of a generic function. This improves the code we generate later and also speeds up compilation, since LLVM has less work to do. Some optimizations are also simpler to perform on MIR than on LLVM IR, because MIR is a higher-level representation. For instance, LLVM does not appear to optimize the pattern that the simplify_try MIR optimization looks for.
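Conceptually, a pass list is a sequence of passes applied one after another to each function body. The sketch below models that idea without the compiler, under clear assumptions: the `Pass` trait, the toy `Body` (just a vector of statement strings) and the `RemoveNops` pass are hypothetical names of our own, not rustc's `MirPass` API.

```rust
/// Toy function body: just a list of statements rendered as strings.
pub struct Body {
    pub statements: Vec<String>,
}

/// Hypothetical pass interface: a pass may rewrite the body in place.
pub trait Pass {
    fn name(&self) -> &'static str;
    fn run_pass(&self, body: &mut Body);
}

/// Illustrative pass (our own): deletes statements that do nothing.
pub struct RemoveNops;

impl Pass for RemoveNops {
    fn name(&self) -> &'static str {
        "RemoveNops"
    }

    fn run_pass(&self, body: &mut Body) {
        body.statements.retain(|s| s != "nop");
    }
}

/// Applies each pass in list order, so later passes see earlier rewrites.
/// Returns the names of the passes that ran, in order.
pub fn run_passes(body: &mut Body, passes: &[&dyn Pass]) -> Vec<&'static str> {
    let mut ran = Vec::new();
    for pass in passes {
        pass.run_pass(body);
        ran.push(pass.name());
    }
    ran
}
```

Because the list is ordered, adding a pass is mostly a matter of choosing where in the sequence it should run.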

However, we can also write simpler analysis passes out of tree, using the same compiler crates that let us retrieve the MIR at various stages of optimization. Let's write a simple analysis pass that prints, for every basic block, its statements and the kind of its terminator (the instruction that ends the block and transfers control).
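Before reaching for the compiler, it may help to picture the data the pass will walk. The following is a hypothetical, heavily simplified model in plain Rust; `Body`, `BasicBlockData`, `Terminator` and `successors` are our own stand-ins that only loosely echo rustc_middle::mir, not the real API. A body is a list of basic blocks, each block holds straight-line statements, and control flow happens only in the block's single terminator.

```rust
/// Index of a basic block ("bb0", "bb1", ... in MIR dumps).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BasicBlock(pub usize);

/// A small, hypothetical subset of terminator kinds.
#[derive(Debug, Clone)]
pub enum Terminator {
    /// Unconditional jump.
    Goto { target: BasicBlock },
    /// Branch on an integer; `otherwise` is taken when no value matches.
    SwitchInt { targets: Vec<(u128, BasicBlock)>, otherwise: BasicBlock },
    /// Call (callee omitted); continues at `target` if it returns normally.
    Call { target: Option<BasicBlock> },
    /// Return from the function.
    Return,
}

/// A block: straight-line statements followed by exactly one terminator.
#[derive(Debug, Clone)]
pub struct BasicBlockData {
    pub statements: Vec<String>,
    pub terminator: Terminator,
}

/// A function body is the list of its blocks; `blocks[0]` is the entry.
#[derive(Debug, Clone)]
pub struct Body {
    pub blocks: Vec<BasicBlockData>,
}

impl Terminator {
    /// Blocks that control can flow to next; only terminators branch.
    pub fn successors(&self) -> Vec<BasicBlock> {
        match self {
            Terminator::Goto { target } => vec![*target],
            Terminator::SwitchInt { targets, otherwise } => {
                let mut succ: Vec<BasicBlock> = targets.iter().map(|(_, bb)| *bb).collect();
                succ.push(*otherwise);
                succ
            }
            // A call that never returns has no normal successor.
            Terminator::Call { target } => target.iter().copied().collect(),
            Terminator::Return => Vec::new(),
        }
    }
}
```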

Let’s start by importing the relevant rustc crates that allow us to query the data structures and elements present at various stages of compilation. To do this, we need to use rustup to modify a toolchain's installed components. We will be using the nightly toolchain, as these internal compiler APIs are intrinsically unstable. We can use the following command to switch to the nightly toolchain for the current directory only.

rustup override set nightly

For more information on switching between toolchains, see the section on overrides in the rustup documentation.

Next, we install the rustc-dev component, which provides the compiler's internal crates and so lets us use rustc as a library, together with llvm-tools-preview:

rustup component add rustc-dev llvm-tools-preview

We will use the rustc_driver crate to hook into the compilation process and insert our functionality. It provides the Callbacks trait, whose methods are called at fixed points during compilation: after parsing (after_parsing), after macro expansion (after_expansion), and after analysis (after_analysis). If we implement this trait for our own struct, we can pass it to rustc_driver::RunCompiler, which calls those methods during compilation.
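The hook pattern behind Callbacks can be illustrated with a small stand-alone model. Everything here is hypothetical: `Hooks`, `Flow`, `run_driver` and `MyHooks` are our own names that loosely echo rustc_driver's Callbacks, Compilation and RunCompiler; the real trait has more methods and different signatures. The key ideas are default method bodies (so an implementor overrides only the hooks it needs) and a return value that tells the driver whether to keep going.

```rust
/// Whether the (toy) driver should continue after a hook.
#[derive(Debug, PartialEq, Eq)]
pub enum Flow {
    Continue,
    Stop,
}

/// Hooks with default bodies, so implementors override only what they need.
pub trait Hooks {
    fn after_parsing(&mut self) -> Flow {
        Flow::Continue
    }
    fn after_analysis(&mut self) -> Flow {
        Flow::Continue
    }
}

/// Toy driver: runs phases in order and consults the hooks between them.
/// Returns the names of the phases that actually ran.
pub fn run_driver(hooks: &mut dyn Hooks) -> Vec<&'static str> {
    let mut phases = vec!["parse"];
    if hooks.after_parsing() == Flow::Stop {
        return phases;
    }
    phases.push("analysis");
    if hooks.after_analysis() == Flow::Stop {
        return phases;
    }
    phases.push("codegen");
    phases
}

/// Our "pass": overrides a single hook and records that it ran.
pub struct MyHooks {
    pub saw_analysis: bool,
}

impl Hooks for MyHooks {
    fn after_analysis(&mut self) -> Flow {
        self.saw_analysis = true;
        // Stop before codegen, as an analysis-only tool might choose to.
        Flow::Stop
    }
}
```

Returning Stop here only demonstrates the control flow; the pass in this post returns Compilation::Continue.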

Let’s start importing the crates. To reiterate, these APIs are unstable and should only be used when dealing with compiler internals.

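#![feature(rustc_private)]
extern crate rustc_ast;
extern crate rustc_driver;
extern crate rustc_hir;
extern crate rustc_interface;
extern crate rustc_middle;
extern crate rustc_session;

use rustc_driver::Compilation;
use rustc_interface::{interface::Compiler, Queries};
use rustc_middle::mir::TerminatorKind;
use rustc_session::EarlyErrorHandler;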

Now we can define a struct for our callbacks and create a RunCompiler instance in main, passing it the file we want to analyze.

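// main.rs

// Struct that carries our callbacks; it holds no state yet.
struct MyCallback;

// Inherent constructor (not a trait impl).
impl MyCallback {
    fn new() -> MyCallback {
        MyCallback
    }
}

// Implement rustc_driver::Callbacks; with an empty body, every hook uses its default.
impl rustc_driver::Callbacks for MyCallback {}

fn main() {
    let mut callbacks = MyCallback::new();
    // The first element plays the role of argv[0] (the program name) and is
    // skipped by rustc's option parser; "test.rs" is the file to analyze.
    let rustc_args = vec!["run".to_string(), "test.rs".to_string()];
    let run_compiler = rustc_driver::RunCompiler::new(&rustc_args, &mut callbacks);
    // run() returns a Result; compilation errors are reported by the compiler itself.
    let _ = run_compiler.run();
}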
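Because the first element of the argument vector is treated as the program name, a tool can simply forward its own command line. A minimal sketch, assuming a hypothetical helper `compiler_args` of our own (not part of rustc_driver) that keeps position 0 and falls back to test.rs when no file is given:

```rust
/// Builds the argument vector for the compiler: element 0 stands in for
/// argv[0] (the program name); the rest are ordinary rustc arguments.
pub fn compiler_args<I: IntoIterator<Item = String>>(cli: I) -> Vec<String> {
    let mut cli = cli.into_iter();
    // Keep our own program name, or a placeholder if there is none.
    let program = cli.next().unwrap_or_else(|| "analyzer".to_string());
    let rest: Vec<String> = cli.collect();
    let mut args = vec![program];
    if rest.is_empty() {
        // Default input file, as in the example above.
        args.push("test.rs".to_string());
    } else {
        args.extend(rest);
    }
    args
}
```

In main, this would be used as `let rustc_args = compiler_args(std::env::args());`.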

Inside our Callbacks implementation, we can now override the after_analysis method with the functionality for our pass, replacing the empty implementation from before. The method returns a Compilation value: Compilation::Continue lets compilation carry on as usual, whereas Compilation::Stop would end it after analysis.

impl rustc_driver::Callbacks for MyCallback {
    fn after_analysis<'tcx>(
        &mut self,
        _handler: &EarlyErrorHandler,
        _compiler: &Compiler,
        _queries: &'tcx Queries<'tcx>,
    ) -> Compilation {
        Compilation::Continue
    }
}

Since we want the terminator kind of each basic block, we collect both the statements of each block and its terminator, the instruction that ends the block. We iterate over the items in the HIR, fetch the optimized MIR of each function, and pattern-match on each block's terminator to find the block it jumps to next.

impl rustc_driver::Callbacks for MyCallback {
    fn after_analysis<'tcx>(
        &mut self,
        _handler: &EarlyErrorHandler,
        _compiler: &Compiler,
        queries: &'tcx Queries<'tcx>,
    ) -> Compilation {
        queries.global_ctxt().unwrap().enter(|tcx| {
            let hir = tcx.hir();
            // Visit the crate's items; associated functions inside impl blocks
            // are impl items and are not yielded by `items()`.
            for id in hir.items() {
                let def_id = id.owner_id.def_id;
                // `optimized_mir` is meant for function bodies, so skip consts,
                // statics and items that have no MIR.
                if matches!(tcx.def_kind(def_id), rustc_hir::def::DefKind::Fn)
                    && tcx.is_mir_available(def_id)
                {
                    let mir_body = tcx.optimized_mir(def_id);
                    let out = mir_body
                        .basic_blocks
                        .iter_enumerated()
                        .map(|(bb, data)| {
                            let term = data.terminator();
                            let kind = &term.kind;
                            let sp = format!("{:?}", &data.statements);
                            match kind {
                                // Terminators with a single normal successor.
                                TerminatorKind::Assert { target, .. }
                                | TerminatorKind::Call {
                                    target: Some(target),
                                    ..
                                }
                                | TerminatorKind::Drop { target, .. }
                                | TerminatorKind::FalseEdge {
                                    real_target: target,
                                    ..
                                }
                                | TerminatorKind::FalseUnwind {
                                    real_target: target,
                                    ..
                                }
                                | TerminatorKind::Goto { target }
                                | TerminatorKind::InlineAsm {
                                    destination: Some(target),
                                    ..
                                }
                                | TerminatorKind::Yield { resume: target, .. } => {
                                    format!("{} are the statements for {:?}:{} -> {:?}\n", sp, bb, term_type(kind), target)
                                }
                                // SwitchInt can branch to several targets.
                                TerminatorKind::SwitchInt { targets, .. } => {
                                    format!("{} are the statements for {:?}:{} -> {:?}\n", sp, bb, term_type(kind), targets)
                                }
                                // Return, Resume, diverging calls, etc.: no successor to print.
                                _ => format!("{} are the statements for {:?}:{}\n", sp, bb, term_type(kind)),
                            }
                        })
                        .collect::<Vec<_>>();

                    println!("basic_blocks info \n");
                    for element in out.iter() {
                        println!("{}", element);
                    }
                }
            }
        });

        Compilation::Continue
    }
}

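// Human-readable name for each terminator kind. The match is exhaustive, so it
// is tied to the nightly it was written against; later nightlies renamed some
// variants (for example, GeneratorDrop became CoroutineDrop), so adjust the
// arms if it fails to compile.
fn term_type(kind: &TerminatorKind<'_>) -> &'static str {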
    match kind {
        TerminatorKind::Goto { .. } => "Goto",
        TerminatorKind::SwitchInt { .. } => "SwitchInt",
        TerminatorKind::Resume => "Resume",
        TerminatorKind::Return => "Return",
        TerminatorKind::Unreachable => "Unreachable",
        TerminatorKind::Drop { .. } => "Drop",
        TerminatorKind::Call { .. } => "Call",
        TerminatorKind::Assert { .. } => "Assert",
        TerminatorKind::Yield { .. } => "Yield",
        TerminatorKind::GeneratorDrop => "GeneratorDrop",
        TerminatorKind::FalseEdge { .. } => "FalseEdge",
        TerminatorKind::FalseUnwind { .. } => "FalseUnwind",
        TerminatorKind::InlineAsm { .. } => "InlineAsm",
        TerminatorKind::Terminate { .. } => "Terminate",
    }
}
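Because the code above only builds on a nightly toolchain with rustc-dev installed, here is a stand-alone sketch of the same per-block report over hypothetical toy types (`Kind`, `Block`, `kind_name` and `describe` are our own names, not rustc's), which can be handy for experimenting with the output format. As in the pass, blocks with a single normal successor print `-> bbN`, SwitchInt prints all of its targets, and everything else prints no target.

```rust
/// Toy terminator kinds (hypothetical; a handful of the real variants).
pub enum Kind {
    Goto(usize),
    Call(Option<usize>), // None models a call that never returns
    Drop(usize),
    SwitchInt(Vec<usize>),
    Return,
    Resume,
}

/// A toy basic block: statements pre-rendered as strings, plus a terminator.
pub struct Block {
    pub statements: Vec<String>,
    pub kind: Kind,
}

/// Name of each toy kind, playing the role of `term_type`.
pub fn kind_name(kind: &Kind) -> &'static str {
    match kind {
        Kind::Goto(_) => "Goto",
        Kind::Call(_) => "Call",
        Kind::Drop(_) => "Drop",
        Kind::SwitchInt(_) => "SwitchInt",
        Kind::Return => "Return",
        Kind::Resume => "Resume",
    }
}

/// One report line per block, shaped like the pass's output.
pub fn describe(blocks: &[Block]) -> Vec<String> {
    blocks
        .iter()
        .enumerate()
        .map(|(i, b)| {
            let head = format!(
                "[{}] are the statements for bb{}:{}",
                b.statements.join(", "),
                i,
                kind_name(&b.kind)
            );
            match &b.kind {
                // Single normal successor.
                Kind::Goto(t) | Kind::Drop(t) | Kind::Call(Some(t)) => format!("{} -> bb{}", head, t),
                // Several possible successors.
                Kind::SwitchInt(ts) => {
                    let names: Vec<String> = ts.iter().map(|t| format!("bb{}", t)).collect();
                    format!("{} -> [{}]", head, names.join(", "))
                }
                // Return, Resume and diverging calls: nothing to print.
                _ => head,
            }
        })
        .collect()
}
```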

We can test it on any Rust file, for example the following test.rs, which is the file named in main.

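// test.rs: used only as input to the analysis. Note that the extern
// declarations below do not match the C prototypes of write and read,
// which take an int descriptor and a void pointer and return ssize_t.
use std::ptr;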

extern "C" {
    fn write(fd: u32, buf: *const char, size: usize) -> u32;
    fn read(fd: u32, buf: *const char, size: usize) -> u32;
}

struct Data {
    val: Vec<char>,
    len: usize,
}

macro_rules! say_hello {
    // `()` indicates that the macro takes no argument.
    () => {
        // The macro will expand into the contents of this block.
        println!("Hello!");
    };
}

fn main() {
    let mut data = Data {
        val: vec!['a', 'b', 'c', '\n'],
        len: 3,
    };
    // say_hello!();
    let x = unsafe { write(2, data.val.as_ptr(), 10) };
    data.val.push('d');
    unsafe { read(0, data.val.as_ptr(), 4) };
}

This is the output for main, one line per basic block:

[_4 = SizeOf([char; 4]), _5 = AlignOf([char; 4])] are the statements for bb0:Call -> bb1

[_7 = ShallowInitBox(move _6, [char; 4]), _16 = (((_7.0: std::ptr::Unique<[char; 4]>).0: std::ptr::NonNull<[char; 4]>).0: *const [char; 4]), _17 = _16 as *const () (PtrToPtr), _18 = _17 as usize (Transmute), _19 = AlignOf(char), _20 = Sub(_19, const 1_usize), _21 = BitAnd(_18, _20), _22 = Eq(_21, const 0_usize)] are the statements for bb1:Assert -> bb11

[_1 = Data { val: move _2, len: const 3_usize }, _10 = &(_1.0: std::vec::Vec<char>)] are the statements for bb2:Call -> bb3

[] are the statements for bb3:Call -> bb4

[_12 = &mut (_1.0: std::vec::Vec<char>)] are the statements for bb4:Call -> bb5

[_15 = &(_1.0: std::vec::Vec<char>)] are the statements for bb5:Call -> bb6

[] are the statements for bb6:Call -> bb7

[] are the statements for bb7:Drop -> bb8

[] are the statements for bb8:Return

[] are the statements for bb9:Drop -> bb10

[] are the statements for bb10:Resume

[(*_16) = [const 'a', const 'b', const 'c', const '\n'], _3 = move _7 as std::boxed::Box<[char]> (PointerCoercion(Unsize))] are the statements for bb11:Call -> bb2
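The output can itself be analyzed. Following only the successor edges printed above from bb0 (bb0 -> bb1 -> bb11 -> bb2 -> ... -> bb8) never reaches bb9 or bb10. Since those blocks still appear in the optimized MIR, they are presumably entered through unwind edges (for example, a call's cleanup edge), which this pass does not print because it reports only the normal target; bb10 ends in Resume, the terminator that continues unwinding. A quick stand-alone check, with the edges copied by hand from the output above (the function names are our own):

```rust
use std::collections::VecDeque;

/// Breadth-first search: which blocks are reachable from `entry`?
/// `succ[i]` lists the successors of block i.
pub fn reachable(succ: &[Vec<usize>], entry: usize) -> Vec<bool> {
    let mut seen = vec![false; succ.len()];
    let mut queue = VecDeque::from([entry]);
    seen[entry] = true;
    while let Some(bb) = queue.pop_front() {
        for &next in &succ[bb] {
            if !seen[next] {
                seen[next] = true;
                queue.push_back(next);
            }
        }
    }
    seen
}

/// Normal successor edges, copied from the printed output.
pub fn printed_edges() -> Vec<Vec<usize>> {
    vec![
        vec![1],  // bb0: Call -> bb1
        vec![11], // bb1: Assert -> bb11
        vec![3],  // bb2: Call -> bb3
        vec![4],  // bb3: Call -> bb4
        vec![5],  // bb4: Call -> bb5
        vec![6],  // bb5: Call -> bb6
        vec![7],  // bb6: Call -> bb7
        vec![8],  // bb7: Drop -> bb8
        vec![],   // bb8: Return
        vec![10], // bb9: Drop -> bb10
        vec![],   // bb10: Resume
        vec![2],  // bb11: Call -> bb2
    ]
}

/// Blocks never reached from bb0 through the printed edges.
pub fn unreached() -> Vec<usize> {
    let seen = reachable(&printed_edges(), 0);
    (0..seen.len()).filter(|&i| !seen[i]).collect()
}
```

Printing the unwind edges as well would be a natural extension of the pass.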

In a future post, we will look at more complex analysis passes, starting with stack safety analysis.