Commit 868c9b74 by dongshufeng

refactor(all): remove some dependencies

parent c36d4633
......@@ -246,7 +246,7 @@ dependencies = [
"log",
"mems",
"nalgebra",
"ndarray",
"ndarray 0.15.6",
"num-complex",
"serde_json",
]
......@@ -273,7 +273,7 @@ dependencies = [
"eig-domain",
"log",
"mems",
"ndarray",
"ndarray 0.15.6",
"petgraph",
"serde_json",
]
......@@ -346,7 +346,7 @@ name = "eig-expr"
version = "0.1.0"
dependencies = [
"fnv",
"ndarray",
"ndarray 0.16.0",
"nom",
"num-complex",
"num-traits",
......@@ -600,6 +600,21 @@ dependencies = [
]
[[package]]
name = "ndarray"
version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "087ee1ca8a7c22830c2bba4a96ed8e72ce0968ae944349324d52522f66aa3944"
dependencies = [
"matrixmultiply",
"num-complex",
"num-integer",
"num-traits",
"portable-atomic",
"portable-atomic-util",
"rawpointer",
]
[[package]]
name = "nom"
version = "7.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
......@@ -692,6 +707,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02"
[[package]]
name = "portable-atomic"
version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265"
[[package]]
name = "portable-atomic-util"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcdd8420072e66d54a407b3316991fe946ce3ab1083a7f575b2463866624704d"
dependencies = [
"portable-atomic",
]
[[package]]
name = "proc-macro2"
version = "1.0.86"
source = "registry+https://github.com/rust-lang/crates.io-index"
......
......@@ -296,3 +296,4 @@ impl PartialEq for AoeModel {
b
}
}
// above should as same as in sparrowzz
\ No newline at end of file
......@@ -13,4 +13,4 @@ rayon = "1.10"
rustc-hash = "2.0"
num-traits = "0.2"
num-complex = "0.4"
ndarray = "0.15"
\ No newline at end of file
ndarray = "0.16"
\ No newline at end of file
......@@ -625,71 +625,6 @@ impl Deref for Expr {
}
}
/// A trait of a source of variables (and constants) and functions for substitution into an
/// evaluated expression.
///
/// A simplest way to create a custom context provider is to use [`Context`](struct.Context.html).
///
/// ## Advanced usage
///
/// Alternatively, values of variables/constants can be specified by tuples `(name, value)`,
/// `std::collections::HashMap` or `std::collections::BTreeMap`.
///
/// use {ContextProvider, Context};
///
/// let mut ctx = Context::new(); // built-ins
/// ctx.var("x", 2.); // insert a new variable
/// assert_eq!(ctx.get_var("pi"), Some(std::f64::consts::PI));
///
/// let myvars = ("x", 2.); // tuple as a ContextProvider
/// assert_eq!(myvars.get_var("x"), Some(2f64));
///
/// // HashMap as a ContextProvider
/// let mut varmap = std::collections::HashMap::new();
/// varmap.insert("x", 2.);
/// varmap.insert("y", 3.);
/// assert_eq!(varmap.get_var("x"), Some(2f64));
/// assert_eq!(varmap.get_var("z"), None);
///
/// Custom functions can be also defined.
///
/// use {ContextProvider, Context};
///
/// let mut ctx = Context::new(); // built-ins
/// ctx.func2("phi", |x, y| x / (y * y));
///
/// assert_eq!(ctx.eval_func("phi", &[2., 3.]), Ok(2. / (3. * 3.)));
///
/// A `ContextProvider` can be built by combining other contexts:
///
/// use Context;
///
/// let bins = Context::new(); // built-ins
/// let mut funcs = Context::empty(); // empty context
/// funcs.func2("phi", |x, y| x / (y * y));
/// let myvars = ("x", 2.);
///
/// // contexts can be combined using tuples
/// let ctx = ((myvars, bins), funcs); // first context has preference if there's duplicity
///
/// assert_eq!(eval_str_with_context("x * pi + phi(1., 2.)", ctx).unwrap(), 2. *
/// std::f64::consts::PI + 1. / (2. * 2.));
///
#[doc(hidden)]
#[cfg(feature = "with_rand")]
pub fn random() -> f64 {
use rand::Rng;
rand::thread_rng().gen::<f64>()
}
#[doc(hidden)]
#[cfg(feature = "with_rand")]
pub fn random2(lower: f64, upper: f64) -> f64 {
use rand::Rng;
rand::thread_rng().gen_range(lower..upper)
}
#[doc(hidden)]
pub fn max_array(xs: &[f64]) -> f64 {
xs.iter().fold(f64::NEG_INFINITY, |m, &x| m.max(x))
......@@ -1064,8 +999,6 @@ impl<'a> Context<'a> {
ctx.var("pi", consts::PI);
ctx.var("PI", consts::PI);
ctx.var("e", consts::E);
#[cfg(feature = "with_rand")]
ctx.func0("rand", random);
ctx.func1("sqrt", f64::sqrt);
ctx.func1("exp", f64::exp);
ctx.func1("ln", f64::ln);
......@@ -1088,8 +1021,6 @@ impl<'a> Context<'a> {
ctx.func1("round", f64::round);
ctx.func1("signum", f64::signum);
ctx.func2("atan2", f64::atan2);
#[cfg(feature = "with_rand")]
ctx.func2("rand2", random2);
ctx.funcn("max", max_array, 1..);
ctx.funcn("min", min_array, 1..);
ctx
......@@ -1303,151 +1234,4 @@ impl<'a> ContextProvider for Context<'a> {
.get(name)
.map_or(Err(FuncEvalError::UnknownFunction), |f| f(args))
}
}
#[cfg(test)]
mod tests {
use std::str::FromStr;
use Error;
use super::*;
#[test]
fn test_eval() {
assert_eq!(eval_str("3 -3"), Ok(0.));
assert_eq!(eval_str("2 + 3"), Ok(5.));
assert_eq!(eval_str("2 + (3 + 4)"), Ok(9.));
assert_eq!(eval_str("-2^(4 - 3) * (3 + 4)"), Ok(-14.));
assert_eq!(eval_str("-2*3! + 1"), Ok(-11.));
assert_eq!(eval_str("-171!"), Ok(f64::MIN));
assert_eq!(eval_str("150!/148!"), Ok(22350.));
assert_eq!(eval_str("a + 3"), Err(Error::UnknownVariable("a".into())));
assert_eq!(eval_str("round(sin (pi) * cos(0))"), Ok(0.));
assert_eq!(eval_str("round( sqrt(3^2 + 4^2)) "), Ok(5.));
assert_eq!(eval_str("max(1.)"), Ok(1.));
assert_eq!(eval_str("max(1., 2., -1)"), Ok(2.));
assert_eq!(eval_str("min(1., 2., -1)"), Ok(-1.));
assert_eq!(
eval_str("sin(1.) + cos(2.)"),
Ok((1f64).sin() + (2f64).cos())
);
assert_eq!(eval_str("10 % 9"), Ok(10f64 % 9f64));
match eval_str("0.5!") {
Err(Error::EvalError(_)) => {}
_ => panic!("Cannot evaluate factorial of non-integer"),
}
}
#[test]
fn test_builtins() {
assert_eq!(eval_str("atan2(1.,2.)"), Ok((1f64).atan2(2.)));
}
#[test]
fn test_eval_func_ctx() {
use std::collections::{BTreeMap, HashMap};
let y = 5.;
assert_eq!(
eval_str_with_context("phi(2.)", Context::new().func1("phi", |x| x + y + 3.)),
Ok(2. + y + 3.)
);
assert_eq!(
eval_str_with_context(
"phi(2., 3.)",
Context::new().func2("phi", |x, y| x + y + 3.),
),
Ok(2. + 3. + 3.)
);
assert_eq!(
eval_str_with_context(
"phi(2., 3., 4.)",
Context::new().func3("phi", |x, y, z| x + y * z),
),
Ok(2. + 3. * 4.)
);
assert_eq!(
eval_str_with_context(
"phi(2., 3.)",
Context::new().funcn("phi", |xs: &[f64]| xs[0] + xs[1], 2),
),
Ok(2. + 3.)
);
let mut m = HashMap::new();
m.insert("x", 2.);
m.insert("y", 3.);
assert_eq!(eval_str_with_context("x + y", &m), Ok(2. + 3.));
assert_eq!(
eval_str_with_context("x + z", m),
Err(Error::UnknownVariable("z".into()))
);
let mut m = BTreeMap::new();
m.insert("x", 2.);
m.insert("y", 3.);
assert_eq!(eval_str_with_context("x + y", &m), Ok(2. + 3.));
assert_eq!(
eval_str_with_context("x + z", m),
Err(Error::UnknownVariable("z".into()))
);
}
#[test]
fn test_bind() {
let expr = Expr::from_str("x + 3").unwrap();
let func = expr.clone().bind("x").unwrap();
assert_eq!(func(1.), 4.);
assert_eq!(
expr.clone().bind("y").err(),
Some(Error::UnknownVariable("x".into()))
);
let ctx = (("x", 2.), builtin());
let func = expr.bind_with_context(&ctx, "y").unwrap();
assert_eq!(func(1.), 5.);
let expr = Expr::from_str("x + y + 2.").unwrap();
let func = expr.clone().bind2("x", "y").unwrap();
assert_eq!(func(1., 2.), 5.);
assert_eq!(
expr.clone().bind2("z", "y").err(),
Some(Error::UnknownVariable("x".into()))
);
assert_eq!(
expr.bind2("x", "z").err(),
Some(Error::UnknownVariable("y".into()))
);
let expr = Expr::from_str("x + y^2 + z^3").unwrap();
let func = expr.bind3("x", "y", "z").unwrap();
assert_eq!(func(1., 2., 3.), 32.);
let expr = Expr::from_str("sin(x)").unwrap();
let func = expr.bind("x").unwrap();
assert_eq!(func(1.), (1f64).sin());
let expr = Expr::from_str("sin(x,2)").unwrap();
match expr.bind("x") {
Err(Error::Function(_, FuncEvalError::NumberArgs(1))) => {}
_ => panic!("bind did not error"),
}
let expr = Expr::from_str("hey(x,2)").unwrap();
match expr.bind("x") {
Err(Error::Function(_, FuncEvalError::UnknownFunction)) => {}
_ => panic!("bind did not error"),
}
}
#[test]
fn hash_context() {
let y = 0.;
{
let z = 0.;
let mut ctx = Context::new();
ctx.var("x", 1.).func1("f", |x| x + y).func1("g", |x| x + z);
ctx.func2("g", |x, y| x + y);
}
}
}
}
\ No newline at end of file
......@@ -302,19 +302,6 @@ pub fn new_cx_angle(r: Complex64, i: Complex64) -> Complex64 {
}
#[doc(hidden)]
#[cfg(feature = "with_rand")]
pub fn random() -> Complex64 {
use rand::Rng;
Complex64::new(rand::thread_rng().gen::<f64>(), 0.)
}
#[doc(hidden)]
#[cfg(feature = "with_rand")]
pub fn random2(lower: Complex64, upper: Complex64) -> Complex64 {
use rand::Rng;
Complex64::new(rand::thread_rng().gen_range(lower.re..upper.re), 0.)
}
#[doc(hidden)]
pub fn abs(v: Complex64) -> Complex64 {
Complex64::new(v.norm(), 0.)
}
......@@ -395,8 +382,6 @@ impl<'a> ContextCx<'a> {
ctx.var("pi", PI);
ctx.var("PI", PI);
ctx.var("e", std::f64::consts::E);
#[cfg(feature = "with_rand")]
ctx.func0("rand", random);
ctx.func1("abs", abs);
ctx.func1("sqrt", Complex64::sqrt);
ctx.func1("exp", Complex64::exp);
......@@ -429,8 +414,6 @@ impl<'a> ContextCx<'a> {
ctx.func2("c1", new_cx_rad);
// 用角度建立复数
ctx.func2("c2", new_cx_angle);
#[cfg(feature = "with_rand")]
ctx.func2("rand2", random2);
ctx.funcn("max", max_array, 1..);
ctx.funcn("min", min_array, 1..);
ctx
......@@ -616,117 +599,4 @@ impl<'a> ContextProvider for ContextCx<'a> {
.get(name)
.map_or(Err(FuncEvalError::UnknownFunction), |f| f(args))
}
}
#[cfg(test)]
mod tests {
use std::f64::consts::PI;
use std::ops::Mul;
use std::str::FromStr;
use approx::assert_relative_eq;
use ndarray::array;
use num_complex::{Complex, Complex64};
use crate::Expr;
use crate::expr_complex::ContextCx;
#[test]
fn it_works() {
let expr = Expr::from_str("1+2").unwrap();
let r = expr.eval_complex();
assert_eq!(r, Ok(Complex64::new(3., 0.)));
let expr = Expr::from_str("c(0,1)+c(2,0.)").unwrap();
let r = expr.eval_complex();
assert_eq!(r, Ok(Complex64::new(2., 1.)));
let expr = Expr::from_str("abs(c(0,1))+c(2,0.)").unwrap();
let r = expr.eval_complex();
assert_eq!(r, Ok(Complex64::new(3., 0.)));
let mut cc = ContextCx::new();
cc.var_cx("a", Complex::new(1., 1.));
let expr = Expr::from_str("a+c(2,0.)").unwrap();
let r = expr.eval_complex_with_ctx(cc.clone());
assert_eq!(r, Ok(Complex64::new(3., 1.)));
}
#[test]
fn test_2_1() {
let mut cc = ContextCx::new();
cc.var("GMRabc", 0.00744);
cc.var("GMRn", 0.00248);
cc.var("rabc", 0.190);
cc.var("rn", 0.368);
cc.var("Dab", 0.7622);
cc.var("Dbc", 1.3720);
cc.var("Dca", 2.1342);
cc.var("Dan", 1.7247);
cc.var("Dbn", 1.3025);
cc.var("Dcn", 1.5244);
// let zaa = Complex64::new(rabc + 0.0493, 0.0628 * (((1.0 / GMRabc) as f64).ln() + 8.02517));
let expr =
Expr::from_str("c(rabc + 0.0493, 0.0628 * (ln(1.0 / GMRabc) + 8.02517))").unwrap();
let zaa = expr.eval_complex_with_ctx(cc.clone()).unwrap();
assert_relative_eq!(zaa.re, 0.2393, max_relative = 1e-4);
assert_relative_eq!(zaa.im, 0.8118, max_relative = 1e-4);
let expr = Expr::from_str("c(0.0493, 0.0628 * (ln(1.0 / Dab) + 8.02517))").unwrap();
let zab = expr.eval_complex_with_ctx(cc.clone()).unwrap();
assert_relative_eq!(zab.im, 0.5210, max_relative = 1e-4);
let expr = Expr::from_str("c(0.0493, 0.0628 * (ln(1.0 / Dca) + 8.02517))").unwrap();
let zac = expr.eval_complex_with_ctx(cc.clone()).unwrap();
assert_relative_eq!(zac.im, 0.4564, max_relative = 1e-4);
let expr = Expr::from_str("c(0.0493, 0.0628 * (ln(1.0 / Dbc) + 8.02517))").unwrap();
let zbc = expr.eval_complex_with_ctx(cc.clone()).unwrap();
assert_relative_eq!(zbc.im, 0.4841, max_relative = 1e-4);
let expr = Expr::from_str("c(0.0493, 0.0628 * (ln(1.0 / Dan) + 8.02517))").unwrap();
let zan = expr.eval_complex_with_ctx(cc.clone()).unwrap();
assert_relative_eq!(zan.im, 0.4698, max_relative = 1e-3);
let expr = Expr::from_str("c(0.0493, 0.0628 * (ln(1.0 / Dbn) + 8.02517))").unwrap();
let zbn = expr.eval_complex_with_ctx(cc.clone()).unwrap();
assert_relative_eq!(zbn.im, 0.4874, max_relative = 1e-4);
let expr = Expr::from_str("c(0.0493, 0.0628 * (ln(1.0 / Dcn) + 8.02517))").unwrap();
let zcn = expr.eval_complex_with_ctx(cc.clone()).unwrap();
assert_relative_eq!(zcn.im, 0.4775, max_relative = 1e-4);
let expr = Expr::from_str("c(rn + 0.0493, 0.0628 * (ln(1.0 / GMRn) + 8.02517))").unwrap();
let znn = expr.eval_complex_with_ctx(cc.clone()).unwrap();
assert_relative_eq!(znn.re, 0.4173, max_relative = 1e-4);
assert_relative_eq!(znn.im, 0.8807, max_relative = 1e-4);
let zij = array![[zaa, zab, zac], [zab, zaa, zbc], [zac, zbc, zaa]];
let zin = array![[zan], [zbn], [zcn]];
let znj = array![zan, zbn, zcn];
let zabc = zij - zin.mul(array![Complex64::new(1.0, 0.0)] / znn).mul(znj);
println!("{:?}", zabc);
let a = Complex64::new(f64::cos(2.0 * PI / 3.0), f64::sin(2.0 * PI / 3.0));
let matrixes = array![
[
Complex64::new(1.0, 0.0),
Complex64::new(1.0, 0.0),
Complex64::new(1.0, 0.0)
],
[Complex64::new(1.0, 0.0), a * a, a],
[Complex64::new(1.0, 0.0), a, a * a]
];
let matrixes_inv = array![
[
Complex64::new(1.0 / 3.0, 0.0),
Complex64::new(1.0 / 3.0, 0.0),
Complex64::new(1.0 / 3.0, 0.0)
],
[
Complex64::new(1.0 / 3.0, 0.0),
a * Complex64::new(1.0 / 3.0, 0.0),
a * a * Complex64::new(1.0 / 3.0, 0.0)
],
[
Complex64::new(1.0 / 3.0, 0.0),
a * a * Complex64::new(1.0 / 3.0, 0.0),
a * Complex64::new(1.0 / 3.0, 0.0)
]
];
let temp = matrixes_inv.dot(&zabc);
let z012 = temp.dot(&matrixes);
assert_relative_eq!(z012.get([0, 0]).unwrap().re, 0.5050, max_relative = 1e-3);
}
}
}
\ No newline at end of file
......@@ -2,8 +2,6 @@ use std::f64::consts::PI;
use fnv::FnvHashMap;
use ndarray::{Array, Ix1, Ix2, IxDyn};
use num_complex::Complex64;
#[cfg(feature = "enable_ndarray_blas")]
use ndarray_linalg::*;
use crate::{CtxProvider, Expr, Operation, Token::*};
use crate::{ContextProvider, Error, factorial, FuncEvalError, MyCx, MyF};
......@@ -40,14 +38,6 @@ impl ContextProvider for CtxProvider {
fn eval_func_tensor_cx(&self, name: &str, args: &[MyCx]) -> Result<MyCx, FuncEvalError> {
DEFAULT_CONTEXT_TENSOR.with(|ctx| ctx.eval_func_tensor_cx(name, args))
}
#[cfg(feature = "enable_ndarray_blas")]
fn matrix_inv(&self, v: &ndarray::Array2<f64>) -> Result<ndarray::Array2<f64>, FuncEvalError> {
v.inv().map_err(|_| FuncEvalError::NumberArgs(0))
}
#[cfg(feature = "enable_ndarray_blas")]
fn matrix_inv_cx(&self, v: &ndarray::Array2<Complex64>) -> Result<ndarray::Array2<Complex64>, FuncEvalError> {
v.inv().map_err(|_| FuncEvalError::NumberArgs(0))
}
}
impl Expr {
......@@ -1243,15 +1233,6 @@ impl<'a> ContextProvider for ContextTensor<'a> {
self.tensors_cx.get(name).cloned()
}
#[cfg(feature = "enable_ndarray_blas")]
fn matrix_inv(&self, v: &ndarray::Array2<f64>) -> Result<ndarray::Array2<f64>, FuncEvalError> {
v.inv().map_err(|_| FuncEvalError::NumberArgs(0))
}
#[cfg(feature = "enable_ndarray_blas")]
fn matrix_inv_cx(&self, v: &ndarray::Array2<Complex64>) -> Result<ndarray::Array2<Complex64>, FuncEvalError> {
v.inv().map_err(|_| FuncEvalError::NumberArgs(0))
}
fn eval_func_tensor(&self, name: &str, args: &[MyF]) -> Result<MyF, FuncEvalError> {
let mut floats = Vec::with_capacity(args.len());
for arg in args {
......@@ -1380,7 +1361,7 @@ impl<'a> ContextProvider for ContextTensor<'a> {
"size" => TsfnBasic::ts_size(args),
"sparse" => TsfnBasic::ts_sparse(args),
"diag" => TsfnBasic::ts_diag(args),
"trace" => TsfnBasic::ts_trace(args),
// "trace" => TsfnBasic::ts_trace(args),
_ => Err(FuncEvalError::UnknownFunction),
}
}
......@@ -1549,182 +1530,10 @@ impl<'a> ContextProvider for ContextTensor<'a> {
.mapv(|a| self.ctx_cx.eval_func_cx("conj", &[a]).unwrap())))
},
"size" => TsfnBasic::ts_size_cx(args),
"eig" => TsfnBasic::ts_eig(args),
"diag" => TsfnBasic::ts_diag_cx(args),
"trace" => TsfnBasic::ts_trace_cx(args),
// "eig" => TsfnBasic::ts_eig(args),
// "diag" => TsfnBasic::ts_diag_cx(args),
// "trace" => TsfnBasic::ts_trace_cx(args),
_ => Err(FuncEvalError::UnknownFunction),
}
}
}
#[cfg(test)]
mod tests {
use std::str::FromStr;
use ndarray::array;
use super::*;
#[test]
fn test_ndarray() {
let a = array![[[1, 2], [2, 3]], [[1, 2], [3, 4]]];
println!("{:?} dim: {}", a.shape(), a.ndim());
let a = array![1, 2];
println!("{:?} dim: {}", a.shape(), a.ndim());
let a = array![1, 2, 3];
println!("{:?} dim: {}", a.shape(), a.ndim());
let a = array![[1, 2, 3]];
println!("{:?} dim: {}", a.shape(), a.ndim());
let a = array![[1], [2], [3]];
println!("{:?} dim: {}", a.shape(), a.ndim());
}
#[test]
fn test_simple() {
let expr = Expr::from_str("[1,2]+[1,3]").unwrap();
let r = expr.eval_tensor();
assert_eq!(r, Ok(MyF::Tensor(array![2., 5.].into_dyn())));
let expr = Expr::from_str("[1,2]*[1,3]").unwrap();
let r = expr.eval_tensor();
assert_eq!(r, Ok(MyF::Tensor(array![1., 6.].into_dyn())));
let expr = Expr::from_str("[1,2]-[1,3]").unwrap();
let r = expr.eval_tensor();
assert_eq!(r, Ok(MyF::Tensor(array![0., -1.].into_dyn())));
let expr = Expr::from_str("[1,2]/[1,3]").unwrap();
let r = expr.eval_tensor();
assert_eq!(r, Ok(MyF::Tensor(array![1., 2. / 3.].into_dyn())));
let expr = Expr::from_str("[1,2] * 2+[1,3]").unwrap();
let r = expr.eval_tensor();
assert_eq!(r, Ok(MyF::Tensor(array![3., 7.].into_dyn())));
let expr = Expr::from_str("[1,2] +1").unwrap();
let r = expr.eval_tensor();
assert_eq!(r, Ok(MyF::Tensor(array![2., 3.].into_dyn())));
let mut ct = ContextTensor::new();
ct.tensor("a", array![1., 2.].into_dyn().to_owned());
let expr = Expr::from_str("a+[1,3]").unwrap();
let r = expr.eval_tensor_with_ctx(ct);
assert_eq!(r, Ok(MyF::Tensor(array![2., 5.].into_dyn())));
// 度转弧度
let expr = Expr::from_str("deg2rad(a)").unwrap();
let mut ct = ContextTensor::new();
ct.tensor("a", array![1., 2.].into_dyn().to_owned());
let r = expr.eval_tensor_with_ctx(ct);
assert_eq!(r, Ok(MyF::Tensor(array![1. * PI / 180., 2. * PI / 180.].into_dyn())));
// 弧度转度
let expr = Expr::from_str("rad2deg(a)").unwrap();
let mut ct = ContextTensor::new();
ct.tensor("a", array![PI, PI / 2.].into_dyn().to_owned());
let r = expr.eval_tensor_with_ctx(ct);
assert_eq!(r, Ok(MyF::Tensor(array![180., 90.].into_dyn())));
// 四舍五入
let expr = Expr::from_str("round(a)").unwrap();
let mut ct = ContextTensor::new();
ct.tensor("a", array![[1.1, 2.5], [3.6, 4.0]].into_dyn().to_owned());
let r = expr.eval_tensor_with_ctx(ct);
assert_eq!(r, Ok(MyF::Tensor(array![[1., 3.], [4., 4.]].into_dyn())));
// 累加矩阵所有元素
let expr = Expr::from_str("sum_all(a)").unwrap();
let mut ct = ContextTensor::new();
ct.tensor("a", array![[1., 2.], [3., 4.]].into_dyn().to_owned());
let r = expr.eval_tensor_with_ctx(ct);
assert_eq!(r, Ok(MyF::F64(10.)));
let expr = Expr::from_str("sum_all(a)").unwrap();
let mut ct = ContextTensor::new();
ct.tensor_cx("a", array![[Complex64::new(1.0, 1.0), Complex64::new(2.0, 2.0)],
[Complex64::new(3.0, 3.0), Complex64::new(4.0, 4.0)]].into_dyn().to_owned());
let r = expr.eval_tensor_with_ctx_cx(ct);
assert_eq!(r, Ok(MyCx::F64(Complex64::new(10., 10.))));
let mut ct = ContextTensor::new();
ct.var_cx("a", Complex64::new(1.0, 2.0));
ct.var_cx("b", Complex64::new(1.0, 2.0));
let expr = Expr::from_str("[a,b] +[1,3]").unwrap();
let r = expr.eval_tensor_with_ctx_cx(ct);
assert_eq!(
r,
Ok(MyCx::Tensor(
array![Complex64::new(2.0, 2.0), Complex64::new(4.0, 2.0)].into_dyn()
))
);
// test conj
let mut ct = ContextTensor::new();
ct.tensor_cx("a", array![Complex64::new(1.0, 2.0), Complex64::new(3.0, 3.0)].into_dyn().to_owned());
let expr = Expr::from_str("conj(a)").unwrap();
let r = expr.eval_tensor_with_ctx_cx(ct);
assert_eq!(r, Ok(MyCx::Tensor(array![Complex64::new(1.0, -2.0), Complex64::new(3.0, -3.0)].into_dyn())));
// test real
let mut ct = ContextTensor::new();
ct.tensor_cx("a", array![Complex64::new(1.0, 2.0), Complex64::new(3.0, 3.0)].into_dyn().to_owned());
let expr = Expr::from_str("real(a)").unwrap();
let r = expr.eval_tensor_with_ctx_cx(ct);
assert_eq!(r, Ok(MyCx::Tensor(array![Complex64::new(1.0, 0.0), Complex64::new(3.0, 0.0)].into_dyn())));
// test imag
let mut ct = ContextTensor::new();
ct.tensor_cx("a", array![Complex64::new(1.0, 2.0), Complex64::new(3.0, 3.0)].into_dyn().to_owned());
let expr = Expr::from_str("imag(a)").unwrap();
let r = expr.eval_tensor_with_ctx_cx(ct);
assert_eq!(r, Ok(MyCx::Tensor(array![Complex64::new(0., 2.0), Complex64::new(0.0, 3.0)].into_dyn())));
// test angle
let mut ct = ContextTensor::new();
ct.tensor_cx("a", array![Complex64::new(1.0, 1.0), Complex64::new(3.0, 0.)].into_dyn().to_owned());
let expr = Expr::from_str("angle(a)").unwrap();
let r = expr.eval_tensor_with_ctx_cx(ct);
assert_eq!(r, Ok(MyCx::Tensor(array![Complex64::new(PI / 4., 0.0), Complex64::new(0.0, 0.0)].into_dyn())));
// test get
let mut ct = ContextTensor::new();
ct.tensor("a", array![1., 2., 3.].into_dyn().to_owned());
let expr = Expr::from_str("get(a,1)").unwrap();
let r = expr.eval_tensor_with_ctx(ct);
assert_eq!(r, Ok(MyF::F64(2.)));
// let mut ct = ContextTensor::new();
// ct.var("a", 1.);
// let expr = Expr::from_str("get(a)").unwrap();
// let r = expr.eval_tensor_with_ctx(ct);
// assert_eq!(r, Ok(MyF::F64(1.)));
}
#[test]
fn test_slice() {
use ndarray::s;
let mut ct = ContextTensor::new();
ct.tensor("a", array![1., 2., 3.].into_dyn().to_owned());
let expr = Expr::from_str("slice(a,0)").unwrap();
let r = expr.eval_tensor_with_ctx(ct);
assert_eq!(r, Ok(MyF::Tensor(array![1., 2., 3.].slice(s![0]).into_dyn().to_owned())));
let mut ct = ContextTensor::new();
let a = array![[1., 2., 3.],[4.,5.,6.],[7.,8.,9.]].into_dyn().to_owned();
ct.tensor("a", a.clone());
let expr = Expr::from_str("slice(a,[0,1],[0])").unwrap();
let r = expr.eval_tensor_with_ctx(ct);
assert_eq!(r, Ok(MyF::Tensor(a.slice(s![..1, ..]).into_dyn().to_owned())));
let mut ct = ContextTensor::new();
let a = array![[1., 2., 3.],[4.,5.,6.],[7.,8.,9.]].into_dyn().to_owned();
ct.tensor("a", a.clone());
let expr = Expr::from_str("slice(a,[0,3,2],[0,3,2])").unwrap();
let r = expr.eval_tensor_with_ctx(ct);
assert_eq!(r, Ok(MyF::Tensor(a.slice(s![..3;2, 0..3;2]).into_dyn().to_owned())));
let mut ct = ContextTensor::new();
let a = array![[Complex64::new(1., 1.), Complex64::new(2., 2.), Complex64::new(3., 3.)],
[Complex64::new(4., 4.), Complex64::new(5., 5.), Complex64::new(6., 6.)],
[Complex64::new(7., 7.), Complex64::new(8., 8.), Complex64::new(9., 9.)]].into_dyn().to_owned();
ct.tensor_cx("a", a.clone());
let expr = Expr::from_str("slice(a,[0,2],[0,2])").unwrap();
let r = expr.eval_tensor_with_ctx_cx(ct);
println!("{:?}", r);
assert_eq!(r, Ok(MyCx::Tensor(a.slice(s![..2, 0..2]).into_dyn().to_owned())));
let mut ct = ContextTensor::new();
let a = array![Complex64::new(1., 2.)].into_dyn();
ct.tensor_cx("a", a.clone());
let expr = Expr::from_str("get(a, 0)").unwrap();
let r = expr.eval_tensor_with_ctx_cx(ct);
println!("{:?}", r);
assert_eq!(r, Ok(MyCx::F64(Complex64::new(1., 2.))));
}
}
}
\ No newline at end of file
......@@ -16,9 +16,6 @@ pub mod tokenizer;
pub mod shuntingyard;
pub mod tsfn_basic;
#[cfg(any(feature = "test", feature = "enable_ndarray_blas"))]
extern crate ndarray_linalg;
#[derive(Debug, Clone, PartialEq)]
pub enum MyF {
F64(f64),
......@@ -430,361 +427,3 @@ pub fn parse_exprs(s: &str) -> Option<Vec<(String, Expr)>> {
}
Some(exprs)
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use crate::expr::{eval_str, Context};
use crate::Expr;
#[test]
fn it_works() {
let r = eval_str("1 + 2").unwrap();
assert_eq!(r, 3.0);
}
#[test]
fn test1() {
assert_eq!(eval_str("3+5*(6-1)-18/3^2").unwrap(), 26.0);
assert_eq!(eval_str("3+5*(6-1)-16%3").unwrap(), 27.0);
assert_eq!(eval_str("1/(2*2+3*3)^0.5").unwrap(), 0.2773500981126146);
let r = eval_str(
"(0 - 1.5625859498977661)/((0 - 1.5625859498977661)*(0 - 1.5625859498977661)+(0 - 0.444685697555542)*(0 - 0.444685697555542))^0.5");
assert!(r.is_ok());
let r = eval_str("(0 - 1E-2)/((0 - 2e-1)*(0 - 1)+ 2E-2^0.5)");
assert!(r.is_ok());
let r = eval_str("1.0 - 3.0 * sin(10 - 3) / 2.0").unwrap();
assert_eq!(r, 1.0 - 3.0 * (10.0_f64 - 3.0_f64).sin() / 2.0);
let r = eval_str("1.0 - 3.0^2 * cos(10 - 3) / 2.0 - sqrt(4)").unwrap();
assert_eq!(
r,
1.0 - 9.0 * (10.0_f64 - 3.0_f64).cos() / 2.0 - 4.0_f64.sqrt()
);
let r = eval_str("1.0 - exp(2.2)^2 * tan(10 - 3) / 2.0 - sqrt(4)").unwrap();
assert!(
(r - (1.0
- 2.2_f64.exp() * 2.2_f64.exp() * (10.0_f64 - 3.0_f64).tan() / 2.0
- 4.0_f64.sqrt()))
.abs()
< 1e-5
);
}
#[test]
fn test2() {
let r: f64 = eval_str("3+5.5").unwrap();
assert_eq!(r, 8.5);
let r: f64 = eval_str("10.0/(10.0*10.0 + 2.0 * 2.0)^0.5").unwrap();
assert!((r - 0.98).abs() < 0.01);
}
#[test]
fn test3() {
let expr: Expr = "10.0/(x*10.0 + 2.0 * 2.0)^0.5".parse().unwrap();
let func = expr.bind("x").unwrap();
let mut vm = HashMap::new();
vm.insert("var1", 10.0);
let ans = func((vm.get("var1").copied()).unwrap());
assert!((ans - 0.98).abs() < 0.01);
}
#[test]
fn test4() {
assert_eq!(eval_str("12-5*2>10").unwrap(), 0.0);
assert_eq!(eval_str("(12-5)*2>10").unwrap(), 1.0);
assert_eq!(eval_str("(12-5)*2>10").unwrap(), 1.0);
assert_eq!(eval_str("30<40/8*5.5").unwrap(), 0.0);
assert_eq!(eval_str("20/3-5<2").unwrap(), 1.0);
assert_eq!(eval_str("25.5 == (6.5-1.5)*5").unwrap(), 0.0);
assert_eq!(eval_str("25 == (6.5-1.5)*5").unwrap(), 1.0);
let expr: Expr = "var1 == (6.5-1.5)*5 ".parse().unwrap();
let func = expr.bind("var1").unwrap();
assert_eq!(func(25.5), 0.0);
let expr: Expr = "var1 + var2>40/8*5.5 ".parse().unwrap();
let mut context = Context::new();
context.var("var1", 20.0);
context.var("var2", 10.0);
assert_eq!(expr.eval_with_context(context).unwrap(), 1.0);
assert_eq!(eval_str("100 - 20 > 2.77 || 20 - 100 > 2.77").unwrap(), 1.0);
}
/*abs绝对值运算测试*/
#[test]
fn test5() {
assert_eq!(eval_str("5.5*abs(-4.2)").unwrap(), 23.1);
assert_eq!(eval_str("3+5*abs(5-7)").unwrap(), 13.);
}
/* 向下取整函数"floor()"测试*/
#[test]
fn test7() {
assert_eq!(eval_str("floor(2.5)*6").unwrap(), 12.0);
assert_eq!(eval_str("5+floor(-1.1)*5").unwrap(), -5.0);
assert_eq!(eval_str("2.5*floor((2.5-1)+3.02)").unwrap(), 10.0);
}
/*提取数据中最大值max,最小值min测试*/
#[test]
fn test8() {
assert_eq!(eval_str("max(7.5,14.8,9.8,2.0,5.5)").unwrap(), 14.8);
assert_eq!(
eval_str("max(47,14,4,70,49,13,35,86,90,71)").unwrap(),
90_f64
);
assert_eq!(
eval_str("max(2.4,9.2,8.6,5.5,8.5,6.2,2.7,0.01,2.5,5.4)").unwrap(),
9.2
);
let mut vm = HashMap::new();
vm.insert("x1", 10.0);
let expr4: Expr = "max(15.0,x1,1)".parse().unwrap();
let func = expr4.bind("x1").unwrap();
assert_eq!(func((vm.get("x1").copied()).unwrap()), 15.0);
assert_eq!(eval_str("min(7.5, 14.8, 9.8, 2.0, 5.5)").unwrap(), 2.0);
assert_eq!(
eval_str("min(47,14,4,70,49,13,35,86,90,71)").unwrap(),
4_f64
);
assert_eq!(
eval_str("min(2.4,9.2,8.6,5.5,8.5,6.2,2.7,0.01,2.5,5.4)").unwrap(),
0.01
);
let expr: Expr = "min(15.0,x,1)".parse().unwrap();
let func = expr.bind("x").unwrap();
assert_eq!(func((vm.get("x1").copied()).unwrap()), 1_f64);
assert_eq!(eval_str("2*3+max(0,1,0.5)").unwrap(), 7_f64);
assert_eq!(eval_str("2*3+min(0,1,0.5)").unwrap(), 6_f64);
assert_eq!(eval_str("2+3*max(1, 2.5, 0.5)").unwrap(), 9.5);
assert_eq!(eval_str("2+3*min(1, 2.5, 0.5)").unwrap(), 3.5);
assert_eq!(eval_str("max(0,max(1,2))").unwrap(), 2_f64);
assert_eq!(eval_str("max(0,1+max(1,2))").unwrap(), 3_f64);
assert_eq!(eval_str("max(2,min(3,4)*max(1,2))").unwrap(), 6_f64);
assert_eq!(eval_str("max(max(1,0), min(1,0))").unwrap(), 1_f64);
assert_eq!(
eval_str("max(max(1,0) + min(1,0), min(1,0) + max(3,5))").unwrap(),
5_f64
);
assert_eq!(
eval_str("min(max(8,2) - min(5,4), min(1,3) - max(3,5))").unwrap(),
-4.0
);
assert_eq!(
eval_str("max(max(6,2) * min(5,4), min(1,0) * max(3,5))").unwrap(),
24_f64
);
assert_eq!(
eval_str("min(max(8,2) / min(5,4), min(1,3) * max(3,5))").unwrap(),
2_f64
);
vm.insert("x2", 4_f64);
let expr21: Expr = "min(max(8,2) / x, min(1,3) * max(3,5))".parse().unwrap();
let func = expr21.bind("x").unwrap();
assert_eq!(func((vm.get("x2").copied()).unwrap()), 2_f64);
}
// bool运算符测试">="和"<="
#[test]
fn test9() {
assert_eq!(eval_str("4 >= 5").unwrap(), 0.0);
assert_eq!(eval_str("7 >= 3").unwrap(), 1.0);
assert_eq!(eval_str("4.5 <= 7.5").unwrap(), 1.0);
let expr: Expr = "x1 >= 1".parse().unwrap();
assert_eq!(expr.bind("x1").unwrap()(10.0), 1.0);
let expr: Expr = "x1 >= 100".parse().unwrap();
assert_eq!(expr.bind("x1").unwrap()(10.0), 0.0);
let expr: Expr = "x1 <= 100".parse().unwrap();
assert_eq!(expr.bind("x1").unwrap()(10.0), 1.0);
let expr: Expr = "x1 <= 1".parse().unwrap();
assert_eq!(expr.bind("x1").unwrap()(10.0), 0.0);
assert_eq!(eval_str("7.5 <= 7.5").unwrap(), 1.0);
assert_eq!(eval_str("5.6 >= 5.6").unwrap(), 1.0);
let expr: Expr = "10 <= x1".parse().unwrap();
assert_eq!(expr.bind("x1").unwrap()(10.0), 1.0);
// test not
let expr: Expr = "~~x1".parse().unwrap();
assert_eq!(expr.bind("x1").unwrap()(10.0), 0.0);
let expr: Expr = "~~x1".parse().unwrap();
assert_eq!(expr.bind("x1").unwrap()(-10.0), 1.0);
}
// bool算符测试"!="
#[test]
fn test10() {
assert_eq!(eval_str("1!=1").unwrap(), 0.0);
assert_eq!(eval_str("0.5 != 1").unwrap(), 1.0);
assert_eq!(eval_str("1.1 != 1").unwrap(), 1.0);
let expr: Expr = "x1 != 9.9".parse().unwrap();
assert_eq!(expr.bind("x1").unwrap()(10.0), 1.0);
let expr: Expr = "x1 != 10.1".parse().unwrap();
assert_eq!(expr.bind("x1").unwrap()(10.0), 1.0);
let expr: Expr = "x1 != 10".parse().unwrap();
assert_eq!(expr.bind("x1").unwrap()(10.0), 0.0);
}
// 负号
#[test]
fn test11() {
assert_eq!(eval_str("2 + 5 - 4").unwrap(), 3.0);
let expr: Expr = "x1-5".parse().unwrap();
assert_eq!(expr.bind("x1").unwrap()(10.0), 5.0);
}
// 四舍五入算符round
#[test]
fn test12() {
assert_eq!(3_f64, eval_str("round(3.2)").unwrap());
assert_eq!(4_f64, eval_str("round(3.6)").unwrap());
assert_eq!(7.5, eval_str("round(5.2) + 2.5").unwrap());
assert_eq!(7.5, eval_str("2.5 + round(4.85)").unwrap());
assert_eq!(12.5, eval_str("2.5 * round(4.85)").unwrap());
assert_eq!(10_f64, eval_str("round(4.85 + 5)").unwrap());
let mut r8 = HashMap::new();
r8.insert("x1", 5.3);
let expr: Expr = "round(x)".parse().unwrap();
assert_eq!(expr.clone().bind("x").unwrap()(5.3), 5_f64);
assert_eq!(expr.bind("x").unwrap()(5.3 * 3.0), 16_f64);
}
// 位运算测试与Not
// 与&;或|;反~;异或^^;左移<<;右移>>
#[test]
fn test13() {
assert_eq!(eval_str("!1").unwrap(), 1.0);
assert_eq!(eval_str("!2").unwrap(), 2.0);
assert_eq!(eval_str("!2!=1").unwrap(), 1.0);
assert_eq!(eval_str("~~!2==0").unwrap(), 1.0);
assert_eq!(eval_str("!~~2==1").unwrap(), 1.0);
assert_eq!(eval_str("~~0").unwrap(), 1.0);
assert_eq!(eval_str("~~0.").unwrap(), 1.0);
assert_eq!(eval_str("~~1").unwrap(), 0.0);
assert_eq!(eval_str("~~1.").unwrap(), 0.0);
assert_eq!(eval_str("~~5").unwrap(), 0.0);
assert_eq!(eval_str("~5").unwrap(), -6.0);
assert_eq!(eval_str("3&5").unwrap(), 1.0);
assert_eq!(eval_str("3|5").unwrap(), 7.0);
assert_eq!(eval_str("3^^5").unwrap(), 6.0);
assert_eq!(eval_str("7>>2").unwrap(), 1.0);
assert_eq!(eval_str("11<<2").unwrap(), 44.0);
let expr: Expr = "~x12".parse().unwrap();
assert_eq!(expr.bind("x12").unwrap()(5.0), -6_f64);
let expr: Expr = "($123+2)&3".parse().unwrap();
assert_eq!(expr.bind("$123").unwrap()(5.0), 3_f64);
let expr: Expr = "x123<<2".parse().unwrap();
assert_eq!(expr.bind("x123").unwrap()(5.0), 20_f64);
let expr: Expr = "(x123+1)>>1".parse().unwrap();
assert_eq!(expr.bind("x123").unwrap()(5.0), 3_f64);
}
// 位运算测试:a@b,返回a的二进制第b位
#[test]
fn test14() {
let expr: Expr = "_123@1".parse().unwrap();
assert_eq!(expr.bind("_123").unwrap()(169.0), 1_f64);
let expr: Expr = "_123@2".parse().unwrap();
assert_eq!(expr.bind("_123").unwrap()(169.0), 0_f64);
let expr: Expr = "_123@3".parse().unwrap();
assert_eq!(expr.bind("_123").unwrap()(169.0), 0_f64);
let expr: Expr = "_123@4".parse().unwrap();
assert_eq!(expr.bind("_123").unwrap()(169.0), 1_f64);
let expr: Expr = "_123@5".parse().unwrap();
assert_eq!(expr.bind("_123").unwrap()(169.0), 0_f64);
let expr: Expr = "_123@6".parse().unwrap();
assert_eq!(expr.bind("_123").unwrap()(169.0), 1_f64);
let expr: Expr = "_123@7".parse().unwrap();
assert_eq!(expr.bind("_123").unwrap()(169.0), 0_f64);
let expr: Expr = "_123@8".parse().unwrap();
assert_eq!(expr.bind("_123").unwrap()(169.0), 1_f64);
let expr: Expr = "$123@1".parse().unwrap();
assert_eq!(expr.bind("$123").unwrap()(29.0), 1_f64);
let expr: Expr = "_123@2".parse().unwrap();
assert_eq!(expr.bind("_123").unwrap()(29.0), 0_f64);
let expr: Expr = "_123@3".parse().unwrap();
assert_eq!(expr.bind("_123").unwrap()(29.0), 1_f64);
let expr: Expr = "_123@4".parse().unwrap();
assert_eq!(expr.bind("_123").unwrap()(29.0), 1_f64);
let expr: Expr = "_123@5".parse().unwrap();
assert_eq!(expr.bind("_123").unwrap()(29.0), 1_f64);
let expr: Expr = "_123@6".parse().unwrap();
assert_eq!(expr.bind("_123").unwrap()(29.0), 0_f64);
let expr: Expr = "_123@7".parse().unwrap();
assert_eq!(expr.bind("_123").unwrap()(29.0), 0_f64);
let expr: Expr = "_123@8".parse().unwrap();
assert_eq!(expr.bind("_123").unwrap()(29.0), 0_f64);
let expr: Expr = "_123@1".parse().unwrap();
assert_eq!(expr.bind("_123").unwrap()(90.0), 0_f64);
let expr: Expr = "_123@2".parse().unwrap();
assert_eq!(expr.bind("_123").unwrap()(90.0), 1_f64);
let expr: Expr = "_123@3".parse().unwrap();
assert_eq!(expr.bind("_123").unwrap()(90.0), 0_f64);
let expr: Expr = "_123@4".parse().unwrap();
assert_eq!(expr.bind("_123").unwrap()(90.0), 1_f64);
let expr: Expr = "_123@5".parse().unwrap();
assert_eq!(expr.bind("_123").unwrap()(90.0), 1_f64);
let expr: Expr = "_123@6".parse().unwrap();
assert_eq!(expr.bind("_123").unwrap()(90.0), 0_f64);
let expr: Expr = "_123@7".parse().unwrap();
assert_eq!(expr.bind("_123").unwrap()(90.0), 1_f64);
let expr: Expr = "_123@8".parse().unwrap();
assert_eq!(expr.bind("_123").unwrap()(90.0), 0_f64);
}
// now()获取系统毫秒数测试
#[test]
fn test15() {
let expr: Expr = "now(0)".parse().unwrap();
let mut context = Context::new();
context.func1("now", get_time_stamp1);
assert!(expr.eval_with_context(context).unwrap() > 0.0);
let expr: Expr = "-now(0)".parse().unwrap();
let mut context = Context::new();
context.func1("now", get_time_stamp1);
assert!(expr.eval_with_context(context).unwrap() < 0.0);
let expr: Expr = "now()".parse().unwrap();
let mut context = Context::new();
context.func0("now", get_time_stamp0);
assert!(expr.eval_with_context(context).unwrap() > 0.0);
}
// test random
#[test]
fn test16() {
let expr: Expr = "rand()".parse().unwrap();
let context = Context::new();
let r = expr.eval_with_context(context).unwrap();
assert!(r >= 0.0);
assert!(r < 1.0);
let expr: Expr = "rand2(1,2)".parse().unwrap();
let context = Context::new();
let r = expr.eval_with_context(context).unwrap();
assert!(r >= 1.0);
assert!(r < 2.0);
}
#[test]
fn test17() {
let expr = "f(a, b, f2(), f3(), c)".parse::<Expr>();
println!("{:?}", expr);
assert!(expr.is_ok());
}
fn get_time_stamp1(_: f64) -> f64 {
let now = std::time::SystemTime::now();
now.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as f64
}
fn get_time_stamp0() -> f64 {
let now = std::time::SystemTime::now();
now.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as f64
}
}
......@@ -620,342 +620,4 @@ pub fn rpn_to_latex(input: &[Token]) -> Result<String, RPNError> {
}
}
Ok(output)
}
#[cfg(test)]
mod tests {
use crate::Operation::*;
use crate::Token::*;
use super::*;
#[test]
fn test_to_rpn() {
assert_eq!(to_rpn(&[Number(1.)]), Ok(vec![Number(1.)]));
assert_eq!(
to_rpn(&[Number(1.), Binary(Plus), Number(2.)]),
Ok(vec![Number(1.), Number(2.), Binary(Plus)])
);
assert_eq!(
to_rpn(&[Unary(Minus), Number(1.), Binary(Pow), Number(2.)]),
Ok(vec![Number(1.), Number(2.), Binary(Pow), Unary(Minus)])
);
assert_eq!(
to_rpn(&[Number(1.), Unary(Fact), Binary(Pow), Number(2.)]),
Ok(vec![Number(1.), Unary(Fact), Number(2.), Binary(Pow)])
);
assert_eq!(
to_rpn(&[
Number(1.),
Unary(Fact),
Binary(Div),
LParen,
Number(2.),
Binary(Plus),
Number(3.),
RParen,
Unary(Fact)
]),
Ok(vec![
Number(1.),
Unary(Fact),
Number(2.),
Number(3.),
Binary(Plus),
Unary(Fact),
Binary(Div),
])
);
assert_eq!(
to_rpn(&[
Number(3.),
Binary(Minus),
Number(1.),
Binary(Times),
Number(2.)
]),
Ok(vec![
Number(3.),
Number(1.),
Number(2.),
Binary(Times),
Binary(Minus),
])
);
assert_eq!(
to_rpn(&[
LParen,
Number(3.),
Binary(Minus),
Number(1.),
RParen,
Binary(Times),
Number(2.)
]),
Ok(vec![
Number(3.),
Number(1.),
Binary(Minus),
Number(2.),
Binary(Times),
])
);
assert_eq!(
to_rpn(&[
Number(1.),
Binary(Minus),
Unary(Minus),
Unary(Minus),
Number(2.)
]),
Ok(vec![
Number(1.),
Number(2.),
Unary(Minus),
Unary(Minus),
Binary(Minus),
])
);
assert_eq!(
to_rpn(&[Var("x".into()), Binary(Plus), Var("y".into())]),
Ok(vec![Var("x".into()), Var("y".into()), Binary(Plus)])
);
assert_eq!(
to_rpn(&[
Func("max".into(), None),
Func("sin".into(), None),
Number(1f64),
RParen,
Comma,
Func("cos".into(), None),
Number(2f64),
RParen,
RParen
]),
Ok(vec![
Number(1f64),
Func("sin".into(), Some(1)),
Number(2f64),
Func("cos".into(), Some(1)),
Func("max".into(), Some(2)),
])
);
assert_eq!(to_rpn(&[Binary(Plus)]), Err(RPNError::NotEnoughOperands(0)));
assert_eq!(
to_rpn(&[Func("f".into(), None), Binary(Plus), RParen]),
Err(RPNError::NotEnoughOperands(0))
);
assert_eq!(
to_rpn(&[Var("x".into()), Number(1.)]),
Err(RPNError::TooManyOperands)
);
assert_eq!(to_rpn(&[LParen]), Err(RPNError::MismatchedLParen(0)));
assert_eq!(to_rpn(&[RParen]), Err(RPNError::MismatchedRParen(0)));
// assert_eq!(
// to_rpn(&[Func("sin".into(), None)]),
// Err(RPNError::MismatchedLParen(0))
// );
assert_eq!(to_rpn(&[Comma]), Err(RPNError::UnexpectedComma(0)));
// assert_eq!(
// to_rpn(&[Func("f".into(), None), Comma]),
// Err(RPNError::MismatchedLParen(0))
// );
assert_eq!(
to_rpn(&[Func("f".into(), None), LParen, Comma, RParen]),
Err(RPNError::UnexpectedComma(2))
);
assert_eq!(
to_rpn(&[
Number(4.),
Binary(Minus),
Number(3.),
Binary(Minus),
Number(1.),
Binary(Times),
Number(2.)
]),
Ok(vec![
Number(4.),
Number(3.),
Binary(Minus),
Number(1.),
Number(2.),
Binary(Times),
Binary(Minus),
])
);
assert_eq!(
to_rpn(&[Tensor(None), Number(1.), Comma, Number(2.), RBracket]),
Ok(vec![Number(1.), Number(2.), Tensor(Some(2))])
);
}
#[test]
fn test_to_infix() {
assert_eq!(
rpn_to_infix(&[
Number(4.),
Number(3.),
Binary(Minus),
Number(1.),
Number(2.),
Binary(Times),
Binary(Minus),
]),
Ok(vec![
Number(4.),
Binary(Minus),
Number(3.),
Binary(Minus),
Number(1.),
Binary(Times),
Number(2.),
])
);
assert_eq!(rpn_to_infix(&[Number(1.)]), Ok(vec![Number(1.)]));
assert_eq!(
rpn_to_infix(&[Number(1.), Number(2.), Binary(Plus)]),
Ok(vec![Number(1.), Binary(Plus), Number(2.)])
);
assert_eq!(
rpn_to_infix(&[Number(1.), Number(2.), Binary(Pow), Unary(Minus)]),
Ok(vec![Unary(Minus), Number(1.), Binary(Pow), Number(2.)])
);
assert_eq!(
rpn_to_infix(&[Number(1.), Unary(Fact), Number(2.), Binary(Pow)]),
Ok(vec![Number(1.), Unary(Fact), Binary(Pow), Number(2.)])
);
assert_eq!(
rpn_to_infix(&[
Number(1.),
Unary(Fact),
Number(2.),
Number(3.),
Binary(Plus),
Unary(Fact),
Binary(Div)
]),
Ok(vec![
Number(1.),
Unary(Fact),
Binary(Div),
LParen,
Number(2.),
Binary(Plus),
Number(3.),
RParen,
Unary(Fact),
])
);
assert_eq!(
rpn_to_infix(&[
Number(3.),
Number(1.),
Number(2.),
Binary(Times),
Binary(Minus)
]),
Ok(vec![
Number(3.),
Binary(Minus),
Number(1.),
Binary(Times),
Number(2.),
])
);
assert_eq!(
rpn_to_infix(&[
Number(3.),
Number(1.),
Binary(Minus),
Number(2.),
Binary(Times)
]),
Ok(vec![
LParen,
Number(3.),
Binary(Minus),
Number(1.),
RParen,
Binary(Times),
Number(2.),
])
);
assert_eq!(
rpn_to_infix(&[
Number(1.),
Number(2.),
Unary(Minus),
Unary(Minus),
Binary(Minus)
]),
Ok(vec![
Number(1.),
Binary(Minus),
LParen,
Unary(Minus),
Unary(Minus),
Number(2.),
RParen,
])
);
assert_eq!(
rpn_to_infix(&[Var("x".into()), Var("y".into()), Binary(Plus)]),
Ok(vec![Var("x".into()), Binary(Plus), Var("y".into())])
);
assert_eq!(
rpn_to_infix(&[
Number(1f64),
Func("sin".into(), Some(1)),
Number(2f64),
Func("cos".into(), Some(1)),
Func("max".into(), Some(2)),
]),
Ok(vec![
Func("max".into(), Some(2)),
Func("sin".into(), Some(1)),
Number(1f64),
RParen,
Comma,
Func("cos".into(), Some(1)),
Number(2f64),
RParen,
RParen,
])
);
assert_eq!(
rpn_to_infix(&[Number(1.), Number(2.), Tensor(Some(2))]),
Ok(vec![Tensor(Some(2)), Number(1.), Comma, Number(2.), RBracket])
);
}
#[test]
fn test_to_latex() {
use crate::Expr;
use crate::tokenizer::tokenize;
use std::str::FromStr;
// let expr = "max((5*1)*x1+3*x2+2*x3+(10-3)*x4+4*x5)";
// let expr = "1*3*x2+sin(8-2)*x3 - cos(pi)< 7";
// let expr = "x1%5+3/3*x2+min(2,5)*x3*2e19 && 1";
// let expr = "(x1+3)^(-2+sin(4^9))^3!--10*x1+x2^(-2+-3)+8<=5*2";
// let expr = "(a^b)!^c";
// let expr = "[1,2]*[3,4]+[a,b]"; //向量
// let expr = "c+-a*-b"; // 单目的括号,c+-a*(-b),这里仍然有些别扭,暂时没有好的办法
// let expr = "1+(2+3*(4*5))"; //乘法和加法多余的括号,可以去括号,但是后缀表达式的顺序会发生更改
let expr = "a*(3*cos(x2)/-8+3)/(a+b)/(c+d)+e==2"; //除法变为分式
let rpn = Expr::from_str(expr).unwrap().rpn;
println!("{:?}", rpn);
let string = rpn_to_string(&rpn).unwrap();
println!("{}", string);
let latex = rpn_to_latex(&rpn).unwrap();
println!("{}", latex);
let rpn_test = to_rpn(&tokenize(string).unwrap()).unwrap();
assert_eq!(rpn, rpn_test);
}
}
}
\ No newline at end of file
//! Tokenizer that converts a mathematical expression in a string form into a series of `Token`s.
//!
//! The underlying parser is build using the [nom] parser combinator crate.
//!
//! The parser should tokenize only well-formed expressions.
//!
//! [nom]: https://crates.io/crates/nom
//!
use std;
use std::fmt;
use std::fmt::{Display, Formatter};
......@@ -344,299 +336,4 @@ impl Display for Token {
Token::Tensor(size) => write!(f, "Tensor({:?})", size),
}
}
}
#[cfg(test)]
mod tests {
use nom::error;
use nom::error::ErrorKind::{Alpha, Digit};
use nom::Err::Error;
use crate::ParseError;
use super::*;
#[test]
fn test_binop() {
assert_eq!(
binop(b"+"),
Ok((&b""[..], Token::Binary(Operation::Plus)))
);
assert_eq!(
binop(b"-"),
Ok((&b""[..], Token::Binary(Operation::Minus)))
);
assert_eq!(
binop(b"*"),
Ok((&b""[..], Token::Binary(Operation::Times)))
);
assert_eq!(
binop(b"/"),
Ok((&b""[..], Token::Binary(Operation::Div)))
);
}
#[test]
fn test_number() {
assert_eq!(
number(b"32143"),
Ok((&b""[..], Token::Number(32143f64)))
);
assert_eq!(
number(b"2."),
Ok((&b""[..], Token::Number(2.0f64)))
);
assert_eq!(
number(b"32143.25"),
Ok((&b""[..], Token::Number(32143.25f64)))
);
assert_eq!(
number(b"0.125e9"),
Ok((&b""[..], Token::Number(0.125e9f64)))
);
assert_eq!(
number(b"20.5E-3"),
Ok((&b""[..], Token::Number(20.5E-3f64)))
);
assert_eq!(
number(b"123423e+50"),
Ok((&b""[..], Token::Number(123423e+50f64)))
);
assert_eq!(
number(b""),
Err(Error(error::Error {
input: &b""[..],
code: Digit
}))
);
assert_eq!(
number(b".2"),
Err(Error(error::Error {
input: &b".2"[..],
code: Digit
}))
);
assert_eq!(
number(b"+"),
Err(Error(error::Error {
input: &b"+"[..],
code: Digit
}))
);
assert_eq!(
number(b"e"),
Err(Error(error::Error {
input: &b"e"[..],
code: Digit
}))
);
assert_eq!(
number(b"1E"),
Err(Error(error::Error {
input: &b""[..],
code: Digit
}))
);
assert_eq!(
number(b"1e+"),
Err(Error(error::Error {
input: &b"+"[..],
code: Digit
}))
);
}
#[test]
fn test_var() {
for &s in ["abc", "U0", "_034", "a_be45EA", "aAzZ_"].iter() {
assert_eq!(
var(s.as_bytes()),
Ok((&b""[..], Token::Var(s.into())))
);
}
for &s in ["\'a\'", "\"U0\"", "\"_034\"", "'*'", "\"+\""].iter() {
assert_eq!(
var(s.as_bytes()),
Ok((&b""[..], tokenize(s).unwrap()[0].clone()))
);
}
assert_eq!(
var(b""),
Err(Error(error::Error {
input: &b""[..],
code: Alpha
}))
);
assert_eq!(
var(b"0"),
Err(Error(error::Error {
input: &b"0"[..],
code: Alpha
}))
);
}
#[test]
fn test_func() {
for &s in ["abc(", "u0(", "_034 (", "A_be45EA ("].iter() {
assert_eq!(
func(s.as_bytes()),
Ok((&b""[..], Token::Func(s[0..s.len() - 1].trim().into(), None)))
);
}
assert_eq!(
func(b""),
Err(Error(error::Error {
input: &b""[..],
code: Alpha
}))
);
assert_eq!(
func(b"("),
Err(Error(error::Error {
input: &b"("[..],
code: Alpha
}))
);
assert_eq!(
func(b"0("),
Err(Error(error::Error {
input: &b"0("[..],
code: Alpha
}))
);
}
#[test]
fn test_tokenize() {
use super::Operation::*;
use super::Token::*;
assert_eq!(tokenize("a"), Ok(vec![Var("a".into())]));
assert_eq!(
tokenize("2 +(3--2) "),
Ok(vec![
Number(2f64),
Binary(Plus),
LParen,
Number(3f64),
Binary(Minus),
Unary(Minus),
Number(2f64),
RParen,
])
);
assert_eq!(
tokenize("-2^ ab0 *12 - C_0"),
Ok(vec![
Unary(Minus),
Number(2f64),
Binary(Pow),
Var("ab0".into()),
Binary(Times),
Number(12f64),
Binary(Minus),
Var("C_0".into()),
])
);
assert_eq!(
tokenize("-sin(pi * 3)^ cos(2) / Func2(x, f(y), z) * _buildIN(y)"),
Ok(vec![
Unary(Minus),
Func("sin".into(), None),
Var("pi".into()),
Binary(Times),
Number(3f64),
RParen,
Binary(Pow),
Func("cos".into(), None),
Number(2f64),
RParen,
Binary(Div),
Func("Func2".into(), None),
Var("x".into()),
Comma,
Func("f".into(), None),
Var("y".into()),
RParen,
Comma,
Var("z".into()),
RParen,
Binary(Times),
Func("_buildIN".into(), None),
Var("y".into()),
RParen,
])
);
assert_eq!(
tokenize("2 % 3"),
Ok(vec![Number(2f64), Binary(Rem), Number(3f64)])
);
assert_eq!(
tokenize("1 + 3! + 1"),
Ok(vec![
Number(1f64),
Binary(Plus),
Number(3f64),
Unary(Fact),
Binary(Plus),
Number(1f64),
])
);
assert_eq!(tokenize("!3"), Err(ParseError::UnexpectedToken(2)));
assert_eq!(tokenize("()"), Err(ParseError::UnexpectedToken(1)));
assert_eq!(tokenize(""), Err(ParseError::MissingArgument));
assert_eq!(tokenize("2)"), Err(ParseError::UnexpectedToken(1)));
assert_eq!(tokenize("2^"), Err(ParseError::MissingArgument));
assert_eq!(tokenize("(((2)"), Err(ParseError::MissingRParen(2)));
assert_eq!(tokenize("f(2,)"), Err(ParseError::UnexpectedToken(1)));
assert_eq!(tokenize("f(,2)"), Err(ParseError::UnexpectedToken(3)));
}
#[test]
fn test_func_with_no_para() {
assert_eq!(
tokenize("f()"),
Ok(vec![Token::Func("f".to_string(), Some(0))])
);
assert_eq!(
tokenize("f( )"),
Ok(vec![Token::Func("f".to_string(), Some(0))])
);
assert!(tokenize("f(f2(1), f3())").is_ok());
assert!(tokenize("f(f2(1), f3(), a)").is_ok());
assert!(tokenize("f(a, b, f2(), f3(), c)").is_ok());
assert!(tokenize("-sin(pi * 3)^ cos(2) / Func2(x, f(), z) * _buildIN()").is_ok());
}
#[test]
fn test_show_latex() {
//let test_token = tokenize("x1^2-10*x1+x2^2+8<=5*2").unwrap();
//let test_token = tokenize("max((5*1)*x1+3*x2+2*x3+(10-3)*x4+4*x5)").unwrap();
//let test_token = tokenize("1*3*x2+sin(8-2)*x3 - cos(pi)< 7").unwrap();
//let test_token = tokenize("x1%5+3/3*x2+min(2,5)*x3*2e19 && 1").unwrap();
//let test_token = tokenize("2!").unwrap();
let test_token = tokenize("~x1").unwrap();
println!("{:?}", test_token);
for x in test_token {
println!("{}", x);
}
}
#[test]
fn test_tensor() {
assert_eq!(
tokenize("[3]"),
Ok(vec![Token::Tensor(None), Token::Number(3.), Token::RBracket])
);
assert!(tokenize("[[1,2],[3,4]]").is_ok());
}
}
}
\ No newline at end of file
// flowing should as same as in sparrowzz
use ndarray::{arr1, Array1, Array2, Axis, IxDyn, SliceInfo, SliceInfoElem};
use num_traits::ToPrimitive;
use crate::{FuncEvalError, MyCx, MyF};
#[cfg(feature = "enable_ndarray_blas")]
use ndarray::{array};
#[cfg(feature = "enable_ndarray_blas")]
use ndarray_linalg::*;
// #[cfg(feature = "enable_ndarray_blas")]
use num_complex::Complex64;
pub trait TsLinalgFn {
......@@ -369,12 +365,12 @@ impl TsfnBasic {
MyF::Tensor(t) => {
if t.ndim() > 1 {
if t.shape().len() == 2 && (t.shape()[0] == 1 || t.shape()[1] == 1) {
Ok(MyF::Tensor(Array2::from_diag(&arr1(t.clone().into_raw_vec().as_slice())).into_dyn()))
Ok(MyF::Tensor(Array2::from_diag(&arr1(t.clone().into_raw_vec_and_offset().0.as_slice())).into_dyn()))
} else {
Ok(MyF::Tensor(t.diag().into_dyn().to_owned()))
}
} else {
Ok(MyF::Tensor(Array2::from_diag(&arr1(t.clone().into_raw_vec().as_slice())).into_dyn()))
Ok(MyF::Tensor(Array2::from_diag(&arr1(t.clone().into_raw_vec_and_offset().0.as_slice())).into_dyn()))
}
}
}
......@@ -386,12 +382,12 @@ impl TsfnBasic {
MyCx::Tensor(t) => {
if t.ndim() > 1 {
if t.shape().len() == 2 && t.shape()[1] == 1 {
Ok(MyCx::Tensor(Array2::from_diag(&arr1(t.clone().into_raw_vec().as_slice())).into_dyn()))
Ok(MyCx::Tensor(Array2::from_diag(&arr1(t.clone().into_raw_vec_and_offset().0.as_slice())).into_dyn()))
} else {
Ok(MyCx::Tensor(t.diag().into_dyn().to_owned()))
}
} else {
Ok(MyCx::Tensor(Array2::from_diag(&arr1(t.clone().into_raw_vec().as_slice())).into_dyn()))
Ok(MyCx::Tensor(Array2::from_diag(&arr1(t.clone().into_raw_vec_and_offset().0.as_slice())).into_dyn()))
}
}
}
......@@ -560,9 +556,9 @@ impl TsfnBasic {
}
if eq_size {
let mut matrix = Array2::zeros([m.to_usize().unwrap(), n.to_usize().unwrap()]);
let i_vec = i.to_owned().into_raw_vec();
let j_vec = j.to_owned().into_raw_vec();
let v_vec = v.to_owned().into_raw_vec();
let i_vec = i.to_owned().into_raw_vec_and_offset().0;
let j_vec = j.to_owned().into_raw_vec_and_offset().0;
let v_vec = v.to_owned().into_raw_vec_and_offset().0;
for k in 0..i.len() {
if i_vec[k] >= *m {
return Err(FuncEvalError::NumberArgs(0))
......@@ -609,59 +605,4 @@ impl TsfnBasic {
}
}
}
impl TsLinalgFn for TsfnBasic {
#[cfg(feature = "enable_ndarray_blas")]
fn ts_eig(args: &[MyCx]) -> Result<MyCx, FuncEvalError> {
match &args[0] {
MyCx::F64(f) => Ok(MyCx::Tensor(array![*f, Complex64::new(1., 0.)].into_dyn())),
MyCx::Tensor(t) => {
if t.ndim() == 2 && t.shape()[0] == t.shape()[1] {
match t.clone().into_dimensionality() {
Ok(t2) => {
let (eigs, _) = t2.eig().map_err(|_| FuncEvalError::UnknownFunction)?;
Ok(MyCx::Tensor(eigs.into_dyn()))
},
Err(_) => Err(FuncEvalError::NumberArgs(0)),
}
} else {
Err(FuncEvalError::NumberArgs(0))
}
}
}
}
#[cfg(feature = "enable_ndarray_blas")]
fn ts_trace(args: &[MyF]) -> Result<MyF, FuncEvalError> {
match &args[0] {
MyF::F64(f) => Ok(MyF::F64(*f)),
MyF::Tensor(t) => {
if t.ndim() == 2 && t.shape()[0] == t.shape()[1] {
match t.clone().into_dimensionality() {
Ok(t2) => Ok(MyF::F64(t2.trace().map_err(|_| FuncEvalError::UnknownFunction)?)),
Err(_) => Err(FuncEvalError::NumberArgs(0)),
}
} else {
Err(FuncEvalError::NumberArgs(0))
}
}
}
}
#[cfg(feature = "enable_ndarray_blas")]
fn ts_trace_cx(args: &[MyCx]) -> Result<MyCx, FuncEvalError> {
match &args[0] {
MyCx::F64(f) => Ok(MyCx::F64(*f)),
MyCx::Tensor(t) => {
if t.ndim() == 2 && t.shape()[0] == t.shape()[1] {
match t.clone().into_dimensionality() {
Ok(t2) => Ok(MyCx::F64(t2.trace().map_err(|_| FuncEvalError::UnknownFunction)?)),
Err(_) => Err(FuncEvalError::NumberArgs(0)),
}
} else {
Err(FuncEvalError::NumberArgs(0))
}
}
}
}
}
// above should as same as in sparrowzz
\ No newline at end of file
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论