diff --git a/src/gas.rs b/src/gas.rs index 86ffc96..b3d377e 100644 --- a/src/gas.rs +++ b/src/gas.rs @@ -1,3 +1,9 @@ +//! This module is used to instrument a Wasm module with gas metering code. +//! +//! The primary public interface is the `inject_gas_counter` function which transforms a given +//! module into one that charges gas for code to be executed. See function documentation for usage +//! and details. + use std::vec::Vec; use parity_wasm::{elements, builder}; @@ -185,10 +191,37 @@ pub fn inject_counter( Ok(()) } -/// Injects gas counter. +/// Transforms a given module into one that charges gas for code to be executed by proxy of an +/// imported gas metering function. /// -/// Can only fail if encounters operation forbidden by gas rules, -/// in this case it returns error with the original module. +/// The output module imports a function "gas" from the module "env" with type signature +/// [i32] -> []. The argument is the amount of gas required to continue execution. The external +/// function is meant to keep track of the total amount of gas used and trap or otherwise halt +/// execution of the runtime if the gas usage exceeds some allowed limit. +/// +/// The calls to charge gas are inserted at the beginning of every block of code. A block is +/// defined by `block`, `if`, `else`, `loop`, and `end` boundaries. Blocks form a nested hierarchy +/// where `block`, `if`, `else`, and `loop` begin a new nested block, and `end` and `else` mark the +/// end of a block. The gas cost of a block is determined statically as 1 plus the gas cost of all +/// instructions directly in that block. Each instruction is only counted in the most deeply +/// nested block containing it (i.e. a block's cost does not include the cost of instructions in any +/// blocks nested within it). The cost of the `block`, `if`, and `loop` instructions is counted +/// towards the block containing them, not the nested block that they open.
There is no gas cost +/// added for `end`/`else`, as they are pseudo-instructions. The gas cost of each instruction is +/// determined by a `rules::Set` parameter. At the beginning of each block, this procedure injects +/// new instructions to call the "gas" function with the gas cost of the block as an argument. +/// +/// Additionally, each `memory.grow` instruction found in the module is instrumented to first make +/// a call to charge gas for the additional pages requested. This cannot be done as part of the +/// block level gas charges as the gas cost is not static and depends on the stack argument to +/// `memory.grow`. +/// +/// The above transformations are performed for every function body defined in the module. This +/// function also rewrites all function indices referenced by code, table elements, etc., since +/// the addition of an imported function changes the indices of module-defined functions. +/// +/// The function fails if the module contains any operation forbidden by the gas rule set, returning +/// the original module as an Err. pub fn inject_gas_counter(module: elements::Module, rules: &rules::Set) -> Result { @@ -212,7 +245,7 @@ pub fn inject_gas_counter(module: elements::Module, rules: &rules::Set) let mut module = mbuilder.build(); // calculate actual function index of the imported definition - // (substract all imports that are NOT functions) + // (subtract all imports that are NOT functions) let gas_func = module.import_count(elements::ImportCountType::Function) as u32 - 1; let total_func = module.functions_space() as u32; @@ -244,6 +277,8 @@ pub fn inject_gas_counter(module: elements::Module, rules: &rules::Set) } }, &mut elements::Section::Element(ref mut elements_section) => { + // Note that we do not need to check the element type referenced because in the + // WebAssembly 1.0 spec, the only allowed element type is funcref.
for ref mut segment in elements_section.entries_mut() { // update all indirect call addresses initial values for func_index in segment.members_mut() { diff --git a/src/stack_height/mod.rs b/src/stack_height/mod.rs index 209ba53..6674717 100644 --- a/src/stack_height/mod.rs +++ b/src/stack_height/mod.rs @@ -39,7 +39,7 @@ //! //! All values are treated equally, as they have the same size. //! -//! The rationale for this it makes it possible to use this very naive wasm executor, that is: +//! The rationale is that this makes it possible to use the following very naive wasm executor: //! //! - values are implemented by a union, so each value takes a size equal to //! the size of the largest possible value type this union can hold. (In MVP it is 8 bytes) @@ -93,35 +93,20 @@ mod thunk; pub struct Error(String); pub(crate) struct Context { - stack_height_global_idx: Option, - func_stack_costs: Option>, + stack_height_global_idx: u32, + func_stack_costs: Vec, stack_limit: u32, } impl Context { /// Returns index in a global index space of a stack_height global variable. - /// - /// Panics if it haven't generated yet. fn stack_height_global_idx(&self) -> u32 { - self.stack_height_global_idx.expect( - "stack_height_global_idx isn't yet generated; - Did you call `inject_stack_counter_global`", - ) + self.stack_height_global_idx } /// Returns `stack_cost` for `func_idx`. - /// - /// Panics if stack costs haven't computed yet or `func_idx` is greater - /// than the last function index. fn stack_cost(&self, func_idx: u32) -> Option { - self.func_stack_costs - .as_ref() - .expect( - "func_stack_costs isn't yet computed; - Did you call `compute_stack_costs`?", - ) - .get(func_idx as usize) - .cloned() + self.func_stack_costs.get(func_idx as usize).cloned() } /// Returns stack limit specified by the rules. 
@@ -142,13 +127,11 @@ pub fn inject_limiter( stack_limit: u32, ) -> Result { let mut ctx = Context { - stack_height_global_idx: None, - func_stack_costs: None, + stack_height_global_idx: generate_stack_height_global(&mut module), + func_stack_costs: compute_stack_costs(&module)?, stack_limit, }; - generate_stack_height_global(&mut ctx, &mut module); - compute_stack_costs(&mut ctx, &module)?; instrument_functions(&mut ctx, &mut module)?; let module = thunk::generate_thunks(&mut ctx, module)?; @@ -156,7 +139,7 @@ pub fn inject_limiter( } /// Generate a new global that will be used for tracking current stack height. -fn generate_stack_height_global(ctx: &mut Context, module: &mut elements::Module) { +fn generate_stack_height_global(module: &mut elements::Module) -> u32 { let global_entry = builder::global() .value_type() .i32() @@ -168,10 +151,7 @@ fn generate_stack_height_global(ctx: &mut Context, module: &mut elements::Module for section in module.sections_mut() { if let elements::Section::Global(ref mut gs) = *section { gs.entries_mut().push(global_entry); - - let stack_height_global_idx = (gs.entries().len() as u32) - 1; - ctx.stack_height_global_idx = Some(stack_height_global_idx); - return; + return (gs.entries().len() as u32) - 1; } } @@ -179,25 +159,26 @@ fn generate_stack_height_global(ctx: &mut Context, module: &mut elements::Module module.sections_mut().push(elements::Section::Global( elements::GlobalSection::with_entries(vec![global_entry]), )); - ctx.stack_height_global_idx = Some(0); + 0 } /// Calculate stack costs for all functions. /// /// Returns a vector with a stack cost for each function, including imports. -fn compute_stack_costs(ctx: &mut Context, module: &elements::Module) -> Result<(), Error> { +fn compute_stack_costs(module: &elements::Module) -> Result, Error> { let func_imports = module.import_count(elements::ImportCountType::Function); - let mut func_stack_costs = vec![0; module.functions_space()]; - // TODO: optimize! 
- for (func_idx, func_stack_cost) in func_stack_costs.iter_mut().enumerate() { - // We can't calculate stack_cost of the import functions. - if func_idx >= func_imports { - *func_stack_cost = compute_stack_cost(func_idx as u32, &module)?; - } - } - ctx.func_stack_costs = Some(func_stack_costs); - Ok(()) + // TODO: optimize! + (0..module.functions_space()) + .map(|func_idx| { + if func_idx < func_imports { + // We can't calculate stack_cost of the import functions. + Ok(0) + } else { + compute_stack_cost(func_idx as u32, &module) + } + }) + .collect() } /// Stack cost of the given *defined* function is the sum of it's locals count (that is, diff --git a/src/stack_height/thunk.rs b/src/stack_height/thunk.rs index f4c3b31..e3edb30 100644 --- a/src/stack_height/thunk.rs +++ b/src/stack_height/thunk.rs @@ -21,9 +21,9 @@ pub(crate) fn generate_thunks( ctx: &mut Context, module: elements::Module, ) -> Result { - // First, we need to collect all function indicies that should be replaced by thunks + // First, we need to collect all function indices that should be replaced by thunks - // Function indicies which needs to generate thunks. + // Function indices for which thunks need to be generated. let mut need_thunks: Vec = Vec::new(); let mut replacement_map: Map = { @@ -38,11 +38,11 @@ pub(crate) fn generate_thunks( let start_func_idx = module .start_section(); - let exported_func_indicies = exports.iter().filter_map(|entry| match *entry.internal() { + let exported_func_indices = exports.iter().filter_map(|entry| match *entry.internal() { Internal::Function(ref function_idx) => Some(*function_idx), _ => None, }); - let table_func_indicies = elem_segments + let table_func_indices = elem_segments .iter() .flat_map(|segment| segment.members()) .cloned(); @@ -50,7 +50,7 @@ pub(crate) fn generate_thunks( // Replacement map is at least export section size.
let mut replacement_map: Map = Map::new(); - for func_idx in exported_func_indicies.chain(table_func_indicies).chain(start_func_idx.into_iter()) { + for func_idx in exported_func_indices.chain(table_func_indices).chain(start_func_idx.into_iter()) { let callee_stack_cost = ctx.stack_cost(func_idx).ok_or_else(|| { Error(format!("function with idx {} isn't found", func_idx)) })?; diff --git a/tests/diff.rs b/tests/diff.rs index e19f6b7..9026fbb 100644 --- a/tests/diff.rs +++ b/tests/diff.rs @@ -122,4 +122,5 @@ mod gas { def_gas_test!(simple); def_gas_test!(start); def_gas_test!(call); + def_gas_test!(branch); } diff --git a/tests/expectations/gas/branch.wat b/tests/expectations/gas/branch.wat new file mode 100644 index 0000000..56c5ecf --- /dev/null +++ b/tests/expectations/gas/branch.wat @@ -0,0 +1,29 @@ +(module + (type (;0;) (func (result i32))) + (type (;1;) (func (param i32))) + (import "env" "gas" (func (;0;) (type 1))) + (func (;1;) (type 0) (result i32) + (local i32 i32) + i32.const 3 + call 0 + block ;; label = @1 + i32.const 17 + call 0 + i32.const 0 + set_local 0 + i32.const 1 + set_local 1 + get_local 0 + get_local 1 + tee_local 0 + i32.add + set_local 1 + i32.const 1 + br_if 0 (;@1;) + get_local 0 + get_local 1 + tee_local 0 + i32.add + set_local 1 + end + get_local 1)) diff --git a/tests/fixtures/gas/branch.wat b/tests/fixtures/gas/branch.wat new file mode 100644 index 0000000..a3b18ce --- /dev/null +++ b/tests/fixtures/gas/branch.wat @@ -0,0 +1,27 @@ +(module + (func $fibonacci_with_break (result i32) + (local $x i32) (local $y i32) + + (block $unrolled_loop + (set_local $x (i32.const 0)) + (set_local $y (i32.const 1)) + + get_local $x + get_local $y + tee_local $x + i32.add + set_local $y + + i32.const 1 + br_if $unrolled_loop + + get_local $x + get_local $y + tee_local $x + i32.add + set_local $y + ) + + get_local $y + ) +) diff --git a/tests/fixtures/gas/call.wat b/tests/fixtures/gas/call.wat index a1ca8fb..ff08b2f 100644 --- 
a/tests/fixtures/gas/call.wat +++ b/tests/fixtures/gas/call.wat @@ -2,7 +2,6 @@ (func $add_locals (param $x i32) (param $y i32) (result i32) (local $t i32) - ;; This looks get_local $x get_local $y call $add diff --git a/tests/fixtures/gas/ifs.wat b/tests/fixtures/gas/ifs.wat index 4e74f50..a2dd3b8 100644 --- a/tests/fixtures/gas/ifs.wat +++ b/tests/fixtures/gas/ifs.wat @@ -1,9 +1,9 @@ (module (func (param $x i32) (result i32) (if (result i32) - (i32.const 1) - (i32.add (get_local $x) (i32.const 1)) - (i32.popcnt (get_local $x)) + (i32.const 1) + (then (i32.add (get_local $x) (i32.const 1))) + (else (i32.popcnt (get_local $x))) ) ) ) diff --git a/tests/fixtures/gas/simple.wat b/tests/fixtures/gas/simple.wat index cee2b6b..1f15d5f 100644 --- a/tests/fixtures/gas/simple.wat +++ b/tests/fixtures/gas/simple.wat @@ -1,9 +1,11 @@ (module (func (export "simple") (if (i32.const 1) - (loop - i32.const 123 - drop + (then + (loop + i32.const 123 + drop + ) ) ) )