diff --git a/tests/snippets/builtin_range.py b/tests/snippets/builtin_range.py index c8efb189d52..667e776a240 100644 --- a/tests/snippets/builtin_range.py +++ b/tests/snippets/builtin_range.py @@ -40,20 +40,24 @@ def assert_raises(expr, exc_type): assert_raises(lambda _: range(10).index('foo'), ValueError) # __bool__ -assert range(1).__bool__() -assert range(1, 2).__bool__() +assert bool(range(1)) +assert bool(range(1, 2)) -assert not range(0).__bool__() -assert not range(1, 1).__bool__() +assert not bool(range(0)) +assert not bool(range(1, 1)) # __contains__ -assert range(10).__contains__(6) -assert range(4, 10).__contains__(6) -assert range(4, 10, 2).__contains__(6) -assert range(10, 4, -2).__contains__(10) -assert range(10, 4, -2).__contains__(8) - -assert not range(10).__contains__(-1) -assert not range(10, 4, -2).__contains__(9) -assert not range(10, 4, -2).__contains__(4) -assert not range(10).__contains__('foo') +assert 6 in range(10) +assert 6 in range(4, 10) +assert 6 in range(4, 10, 2) +assert 10 in range(10, 4, -2) +assert 8 in range(10, 4, -2) + +assert -1 not in range(10) +assert 9 not in range(10, 4, -2) +assert 4 not in range(10, 4, -2) +assert 'foo' not in range(10) + +# __reversed__ +assert list(reversed(range(5))) == [4, 3, 2, 1, 0] +assert list(reversed(range(5, 0, -1))) == [1, 2, 3, 4, 5] diff --git a/tests/snippets/builtin_reversed.py b/tests/snippets/builtin_reversed.py new file mode 100644 index 00000000000..2bbfcb98a2c --- /dev/null +++ b/tests/snippets/builtin_reversed.py @@ -0,0 +1 @@ +assert list(reversed(range(5))) == [4, 3, 2, 1, 0] diff --git a/vm/src/builtins.rs b/vm/src/builtins.rs index 75cd7afaf24..1321a96e460 100644 --- a/vm/src/builtins.rs +++ b/vm/src/builtins.rs @@ -587,6 +587,19 @@ fn builtin_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(obj, None)]); vm.to_repr(obj) } + +fn builtin_reversed(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(obj, None)]); + + match vm.get_method(obj.clone(), "__reversed__") { + Ok(value) => vm.invoke(value, PyFuncArgs::default()), + // TODO: fallback to using __len__ and __getitem__, if object supports sequence protocol + Err(..) => Err(vm.new_type_error(format!( + "'{}' object is not reversible", + objtype::get_type_name(&obj.typ()), + ))), + } +} // builtin_reversed fn builtin_round(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { @@ -692,6 +705,7 @@ pub fn make_module(ctx: &PyContext) -> PyObjectRef { ctx.set_attr(&py_mod, "property", ctx.property_type()); ctx.set_attr(&py_mod, "range", ctx.range_type()); ctx.set_attr(&py_mod, "repr", ctx.new_rustfunc(builtin_repr)); + ctx.set_attr(&py_mod, "reversed", ctx.new_rustfunc(builtin_reversed)); ctx.set_attr(&py_mod, "round", ctx.new_rustfunc(builtin_round)); ctx.set_attr(&py_mod, "set", ctx.set_type()); ctx.set_attr(&py_mod, "setattr", ctx.new_rustfunc(builtin_setattr)); diff --git a/vm/src/obj/objrange.rs b/vm/src/obj/objrange.rs index 717027403c0..221e98da506 100644 --- a/vm/src/obj/objrange.rs +++ b/vm/src/obj/objrange.rs @@ -94,6 +94,22 @@ impl RangeType { } #[inline] + pub fn reversed(&self) -> Self { + match self.step.sign() { + Sign::Plus => RangeType { + start: &self.end - 1, + end: &self.start - 1, + step: -&self.step, + }, + Sign::Minus => RangeType { + start: &self.end + 1, + end: &self.start + 1, + step: -&self.step, + }, + Sign::NoSign => unreachable!(), + } + } + pub fn repr(&self) -> String { if self.step == BigInt::one() { format!("range({}, {})", self.start, self.end) @@ -116,6 +132,11 @@ pub fn init(context: &PyContext) { context.set_attr(&range_type, "__new__", context.new_rustfunc(range_new)); context.set_attr(&range_type, "__iter__", context.new_rustfunc(range_iter)); + context.set_attr( + &range_type, + "__reversed__", + context.new_rustfunc(range_reversed), + ); context.set_attr( &range_type, "__doc__", @@ -190,6 +211,23 @@ fn range_iter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { )) } +fn range_reversed(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(zelf, Some(vm.ctx.range_type()))]); + + let range = match zelf.borrow().payload { + PyObjectPayload::Range { ref range } => range.reversed(), + _ => unreachable!(), + }; + + Ok(PyObject::new( + PyObjectPayload::Iterator { + position: 0, + iterated_obj: PyObject::new(PyObjectPayload::Range { range }, vm.ctx.range_type()), + }, + vm.ctx.iter_type(), + )) +} + fn range_len(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(zelf, Some(vm.ctx.range_type()))]);