Commit 868c9b74 by dongshufeng

refactor(all): remove some dependencies

parent c36d4633
......@@ -246,7 +246,7 @@ dependencies = [
"log",
"mems",
"nalgebra",
"ndarray",
"ndarray 0.15.6",
"num-complex",
"serde_json",
]
......@@ -273,7 +273,7 @@ dependencies = [
"eig-domain",
"log",
"mems",
"ndarray",
"ndarray 0.15.6",
"petgraph",
"serde_json",
]
......@@ -346,7 +346,7 @@ name = "eig-expr"
version = "0.1.0"
dependencies = [
"fnv",
"ndarray",
"ndarray 0.16.0",
"nom",
"num-complex",
"num-traits",
......@@ -600,6 +600,21 @@ dependencies = [
]
[[package]]
name = "ndarray"
version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "087ee1ca8a7c22830c2bba4a96ed8e72ce0968ae944349324d52522f66aa3944"
dependencies = [
"matrixmultiply",
"num-complex",
"num-integer",
"num-traits",
"portable-atomic",
"portable-atomic-util",
"rawpointer",
]
[[package]]
name = "nom"
version = "7.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
......@@ -692,6 +707,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02"
[[package]]
name = "portable-atomic"
version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265"
[[package]]
name = "portable-atomic-util"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcdd8420072e66d54a407b3316991fe946ce3ab1083a7f575b2463866624704d"
dependencies = [
"portable-atomic",
]
[[package]]
name = "proc-macro2"
version = "1.0.86"
source = "registry+https://github.com/rust-lang/crates.io-index"
......
......@@ -296,3 +296,4 @@ impl PartialEq for AoeModel {
b
}
}
// above should as same as in sparrowzz
\ No newline at end of file
......@@ -13,4 +13,4 @@ rayon = "1.10"
rustc-hash = "2.0"
num-traits = "0.2"
num-complex = "0.4"
ndarray = "0.15"
\ No newline at end of file
ndarray = "0.16"
\ No newline at end of file
......@@ -625,71 +625,6 @@ impl Deref for Expr {
}
}
/// A trait of a source of variables (and constants) and functions for substitution into an
/// evaluated expression.
///
/// A simplest way to create a custom context provider is to use [`Context`](struct.Context.html).
///
/// ## Advanced usage
///
/// Alternatively, values of variables/constants can be specified by tuples `(name, value)`,
/// `std::collections::HashMap` or `std::collections::BTreeMap`.
///
/// use {ContextProvider, Context};
///
/// let mut ctx = Context::new(); // built-ins
/// ctx.var("x", 2.); // insert a new variable
/// assert_eq!(ctx.get_var("pi"), Some(std::f64::consts::PI));
///
/// let myvars = ("x", 2.); // tuple as a ContextProvider
/// assert_eq!(myvars.get_var("x"), Some(2f64));
///
/// // HashMap as a ContextProvider
/// let mut varmap = std::collections::HashMap::new();
/// varmap.insert("x", 2.);
/// varmap.insert("y", 3.);
/// assert_eq!(varmap.get_var("x"), Some(2f64));
/// assert_eq!(varmap.get_var("z"), None);
///
/// Custom functions can be also defined.
///
/// use {ContextProvider, Context};
///
/// let mut ctx = Context::new(); // built-ins
/// ctx.func2("phi", |x, y| x / (y * y));
///
/// assert_eq!(ctx.eval_func("phi", &[2., 3.]), Ok(2. / (3. * 3.)));
///
/// A `ContextProvider` can be built by combining other contexts:
///
/// use Context;
///
/// let bins = Context::new(); // built-ins
/// let mut funcs = Context::empty(); // empty context
/// funcs.func2("phi", |x, y| x / (y * y));
/// let myvars = ("x", 2.);
///
/// // contexts can be combined using tuples
/// let ctx = ((myvars, bins), funcs); // first context has preference if there's duplicity
///
/// assert_eq!(eval_str_with_context("x * pi + phi(1., 2.)", ctx).unwrap(), 2. *
/// std::f64::consts::PI + 1. / (2. * 2.));
///
#[doc(hidden)]
#[cfg(feature = "with_rand")]
pub fn random() -> f64 {
use rand::Rng;
rand::thread_rng().gen::<f64>()
}
#[doc(hidden)]
#[cfg(feature = "with_rand")]
pub fn random2(lower: f64, upper: f64) -> f64 {
use rand::Rng;
rand::thread_rng().gen_range(lower..upper)
}
#[doc(hidden)]
pub fn max_array(xs: &[f64]) -> f64 {
xs.iter().fold(f64::NEG_INFINITY, |m, &x| m.max(x))
......@@ -1064,8 +999,6 @@ impl<'a> Context<'a> {
ctx.var("pi", consts::PI);
ctx.var("PI", consts::PI);
ctx.var("e", consts::E);
#[cfg(feature = "with_rand")]
ctx.func0("rand", random);
ctx.func1("sqrt", f64::sqrt);
ctx.func1("exp", f64::exp);
ctx.func1("ln", f64::ln);
......@@ -1088,8 +1021,6 @@ impl<'a> Context<'a> {
ctx.func1("round", f64::round);
ctx.func1("signum", f64::signum);
ctx.func2("atan2", f64::atan2);
#[cfg(feature = "with_rand")]
ctx.func2("rand2", random2);
ctx.funcn("max", max_array, 1..);
ctx.funcn("min", min_array, 1..);
ctx
......@@ -1303,151 +1234,4 @@ impl<'a> ContextProvider for Context<'a> {
.get(name)
.map_or(Err(FuncEvalError::UnknownFunction), |f| f(args))
}
}
#[cfg(test)]
mod tests {
use std::str::FromStr;
use Error;
use super::*;
#[test]
fn test_eval() {
assert_eq!(eval_str("3 -3"), Ok(0.));
assert_eq!(eval_str("2 + 3"), Ok(5.));
assert_eq!(eval_str("2 + (3 + 4)"), Ok(9.));
assert_eq!(eval_str("-2^(4 - 3) * (3 + 4)"), Ok(-14.));
assert_eq!(eval_str("-2*3! + 1"), Ok(-11.));
assert_eq!(eval_str("-171!"), Ok(f64::MIN));
assert_eq!(eval_str("150!/148!"), Ok(22350.));
assert_eq!(eval_str("a + 3"), Err(Error::UnknownVariable("a".into())));
assert_eq!(eval_str("round(sin (pi) * cos(0))"), Ok(0.));
assert_eq!(eval_str("round( sqrt(3^2 + 4^2)) "), Ok(5.));
assert_eq!(eval_str("max(1.)"), Ok(1.));
assert_eq!(eval_str("max(1., 2., -1)"), Ok(2.));
assert_eq!(eval_str("min(1., 2., -1)"), Ok(-1.));
assert_eq!(
eval_str("sin(1.) + cos(2.)"),
Ok((1f64).sin() + (2f64).cos())
);
assert_eq!(eval_str("10 % 9"), Ok(10f64 % 9f64));
match eval_str("0.5!") {
Err(Error::EvalError(_)) => {}
_ => panic!("Cannot evaluate factorial of non-integer"),
}
}
#[test]
fn test_builtins() {
assert_eq!(eval_str("atan2(1.,2.)"), Ok((1f64).atan2(2.)));
}
#[test]
fn test_eval_func_ctx() {
use std::collections::{BTreeMap, HashMap};
let y = 5.;
assert_eq!(
eval_str_with_context("phi(2.)", Context::new().func1("phi", |x| x + y + 3.)),
Ok(2. + y + 3.)
);
assert_eq!(
eval_str_with_context(
"phi(2., 3.)",
Context::new().func2("phi", |x, y| x + y + 3.),
),
Ok(2. + 3. + 3.)
);
assert_eq!(
eval_str_with_context(
"phi(2., 3., 4.)",
Context::new().func3("phi", |x, y, z| x + y * z),
),
Ok(2. + 3. * 4.)
);
assert_eq!(
eval_str_with_context(
"phi(2., 3.)",
Context::new().funcn("phi", |xs: &[f64]| xs[0] + xs[1], 2),
),
Ok(2. + 3.)
);
let mut m = HashMap::new();
m.insert("x", 2.);
m.insert("y", 3.);
assert_eq!(eval_str_with_context("x + y", &m), Ok(2. + 3.));
assert_eq!(
eval_str_with_context("x + z", m),
Err(Error::UnknownVariable("z".into()))
);
let mut m = BTreeMap::new();
m.insert("x", 2.);
m.insert("y", 3.);
assert_eq!(eval_str_with_context("x + y", &m), Ok(2. + 3.));
assert_eq!(
eval_str_with_context("x + z", m),
Err(Error::UnknownVariable("z".into()))
);
}
#[test]
fn test_bind() {
let expr = Expr::from_str("x + 3").unwrap();
let func = expr.clone().bind("x").unwrap();
assert_eq!(func(1.), 4.);
assert_eq!(
expr.clone().bind("y").err(),
Some(Error::UnknownVariable("x".into()))
);
let ctx = (("x", 2.), builtin());
let func = expr.bind_with_context(&ctx, "y").unwrap();
assert_eq!(func(1.), 5.);
let expr = Expr::from_str("x + y + 2.").unwrap();
let func = expr.clone().bind2("x", "y").unwrap();
assert_eq!(func(1., 2.), 5.);
assert_eq!(
expr.clone().bind2("z", "y").err(),
Some(Error::UnknownVariable("x".into()))
);
assert_eq!(
expr.bind2("x", "z").err(),
Some(Error::UnknownVariable("y".into()))
);
let expr = Expr::from_str("x + y^2 + z^3").unwrap();
let func = expr.bind3("x", "y", "z").unwrap();
assert_eq!(func(1., 2., 3.), 32.);
let expr = Expr::from_str("sin(x)").unwrap();
let func = expr.bind("x").unwrap();
assert_eq!(func(1.), (1f64).sin());
let expr = Expr::from_str("sin(x,2)").unwrap();
match expr.bind("x") {
Err(Error::Function(_, FuncEvalError::NumberArgs(1))) => {}
_ => panic!("bind did not error"),
}
let expr = Expr::from_str("hey(x,2)").unwrap();
match expr.bind("x") {
Err(Error::Function(_, FuncEvalError::UnknownFunction)) => {}
_ => panic!("bind did not error"),
}
}
#[test]
fn hash_context() {
let y = 0.;
{
let z = 0.;
let mut ctx = Context::new();
ctx.var("x", 1.).func1("f", |x| x + y).func1("g", |x| x + z);
ctx.func2("g", |x, y| x + y);
}
}
}
}
\ No newline at end of file
......@@ -302,19 +302,6 @@ pub fn new_cx_angle(r: Complex64, i: Complex64) -> Complex64 {
}
#[doc(hidden)]
#[cfg(feature = "with_rand")]
pub fn random() -> Complex64 {
use rand::Rng;
Complex64::new(rand::thread_rng().gen::<f64>(), 0.)
}
#[doc(hidden)]
#[cfg(feature = "with_rand")]
pub fn random2(lower: Complex64, upper: Complex64) -> Complex64 {
use rand::Rng;
Complex64::new(rand::thread_rng().gen_range(lower.re..upper.re), 0.)
}
#[doc(hidden)]
pub fn abs(v: Complex64) -> Complex64 {
Complex64::new(v.norm(), 0.)
}
......@@ -395,8 +382,6 @@ impl<'a> ContextCx<'a> {
ctx.var("pi", PI);
ctx.var("PI", PI);
ctx.var("e", std::f64::consts::E);
#[cfg(feature = "with_rand")]
ctx.func0("rand", random);
ctx.func1("abs", abs);
ctx.func1("sqrt", Complex64::sqrt);
ctx.func1("exp", Complex64::exp);
......@@ -429,8 +414,6 @@ impl<'a> ContextCx<'a> {
ctx.func2("c1", new_cx_rad);
// 用角度建立复数
ctx.func2("c2", new_cx_angle);
#[cfg(feature = "with_rand")]
ctx.func2("rand2", random2);
ctx.funcn("max", max_array, 1..);
ctx.funcn("min", min_array, 1..);
ctx
......@@ -616,117 +599,4 @@ impl<'a> ContextProvider for ContextCx<'a> {
.get(name)
.map_or(Err(FuncEvalError::UnknownFunction), |f| f(args))
}
}
#[cfg(test)]
mod tests {
use std::f64::consts::PI;
use std::ops::Mul;
use std::str::FromStr;
use approx::assert_relative_eq;
use ndarray::array;
use num_complex::{Complex, Complex64};
use crate::Expr;
use crate::expr_complex::ContextCx;
#[test]
fn it_works() {
let expr = Expr::from_str("1+2").unwrap();
let r = expr.eval_complex();
assert_eq!(r, Ok(Complex64::new(3., 0.)));
let expr = Expr::from_str("c(0,1)+c(2,0.)").unwrap();
let r = expr.eval_complex();
assert_eq!(r, Ok(Complex64::new(2., 1.)));
let expr = Expr::from_str("abs(c(0,1))+c(2,0.)").unwrap();
let r = expr.eval_complex();
assert_eq!(r, Ok(Complex64::new(3., 0.)));
let mut cc = ContextCx::new();
cc.var_cx("a", Complex::new(1., 1.));
let expr = Expr::from_str("a+c(2,0.)").unwrap();
let r = expr.eval_complex_with_ctx(cc.clone());
assert_eq!(r, Ok(Complex64::new(3., 1.)));
}
#[test]
fn test_2_1() {
let mut cc = ContextCx::new();
cc.var("GMRabc", 0.00744);
cc.var("GMRn", 0.00248);
cc.var("rabc", 0.190);
cc.var("rn", 0.368);
cc.var("Dab", 0.7622);
cc.var("Dbc", 1.3720);
cc.var("Dca", 2.1342);
cc.var("Dan", 1.7247);
cc.var("Dbn", 1.3025);
cc.var("Dcn", 1.5244);
// let zaa = Complex64::new(rabc + 0.0493, 0.0628 * (((1.0 / GMRabc) as f64).ln() + 8.02517));
let expr =
Expr::from_str("c(rabc + 0.0493, 0.0628 * (ln(1.0 / GMRabc) + 8.02517))").unwrap();
let zaa = expr.eval_complex_with_ctx(cc.clone()).unwrap();
assert_relative_eq!(zaa.re, 0.2393, max_relative = 1e-4);
assert_relative_eq!(zaa.im, 0.8118, max_relative = 1e-4);
let expr = Expr::from_str("c(0.0493, 0.0628 * (ln(1.0 / Dab) + 8.02517))").unwrap();
let zab = expr.eval_complex_with_ctx(cc.clone()).unwrap();
assert_relative_eq!(zab.im, 0.5210, max_relative = 1e-4);
let expr = Expr::from_str("c(0.0493, 0.0628 * (ln(1.0 / Dca) + 8.02517))").unwrap();
let zac = expr.eval_complex_with_ctx(cc.clone()).unwrap();
assert_relative_eq!(zac.im, 0.4564, max_relative = 1e-4);
let expr = Expr::from_str("c(0.0493, 0.0628 * (ln(1.0 / Dbc) + 8.02517))").unwrap();
let zbc = expr.eval_complex_with_ctx(cc.clone()).unwrap();
assert_relative_eq!(zbc.im, 0.4841, max_relative = 1e-4);
let expr = Expr::from_str("c(0.0493, 0.0628 * (ln(1.0 / Dan) + 8.02517))").unwrap();
let zan = expr.eval_complex_with_ctx(cc.clone()).unwrap();
assert_relative_eq!(zan.im, 0.4698, max_relative = 1e-3);
let expr = Expr::from_str("c(0.0493, 0.0628 * (ln(1.0 / Dbn) + 8.02517))").unwrap();
let zbn = expr.eval_complex_with_ctx(cc.clone()).unwrap();
assert_relative_eq!(zbn.im, 0.4874, max_relative = 1e-4);
let expr = Expr::from_str("c(0.0493, 0.0628 * (ln(1.0 / Dcn) + 8.02517))").unwrap();
let zcn = expr.eval_complex_with_ctx(cc.clone()).unwrap();
assert_relative_eq!(zcn.im, 0.4775, max_relative = 1e-4);
let expr = Expr::from_str("c(rn + 0.0493, 0.0628 * (ln(1.0 / GMRn) + 8.02517))").unwrap();
let znn = expr.eval_complex_with_ctx(cc.clone()).unwrap();
assert_relative_eq!(znn.re, 0.4173, max_relative = 1e-4);
assert_relative_eq!(znn.im, 0.8807, max_relative = 1e-4);
let zij = array![[zaa, zab, zac], [zab, zaa, zbc], [zac, zbc, zaa]];
let zin = array![[zan], [zbn], [zcn]];
let znj = array![zan, zbn, zcn];
let zabc = zij - zin.mul(array![Complex64::new(1.0, 0.0)] / znn).mul(znj);
println!("{:?}", zabc);
let a = Complex64::new(f64::cos(2.0 * PI / 3.0), f64::sin(2.0 * PI / 3.0));
let matrixes = array![
[
Complex64::new(1.0, 0.0),
Complex64::new(1.0, 0.0),
Complex64::new(1.0, 0.0)
],
[Complex64::new(1.0, 0.0), a * a, a],
[Complex64::new(1.0, 0.0), a, a * a]
];
let matrixes_inv = array![
[
Complex64::new(1.0 / 3.0, 0.0),
Complex64::new(1.0 / 3.0, 0.0),
Complex64::new(1.0 / 3.0, 0.0)
],
[
Complex64::new(1.0 / 3.0, 0.0),
a * Complex64::new(1.0 / 3.0, 0.0),
a * a * Complex64::new(1.0 / 3.0, 0.0)
],
[
Complex64::new(1.0 / 3.0, 0.0),
a * a * Complex64::new(1.0 / 3.0, 0.0),
a * Complex64::new(1.0 / 3.0, 0.0)
]
];
let temp = matrixes_inv.dot(&zabc);
let z012 = temp.dot(&matrixes);
assert_relative_eq!(z012.get([0, 0]).unwrap().re, 0.5050, max_relative = 1e-3);
}
}
}
\ No newline at end of file
//! Tokenizer that converts a mathematical expression in a string form into a series of `Token`s.
//!
//! The underlying parser is build using the [nom] parser combinator crate.
//!
//! The parser should tokenize only well-formed expressions.
//!
//! [nom]: https://crates.io/crates/nom
//!
use std;
use std::fmt;
use std::fmt::{Display, Formatter};
......@@ -344,299 +336,4 @@ impl Display for Token {
Token::Tensor(size) => write!(f, "Tensor({:?})", size),
}
}
}
#[cfg(test)]
mod tests {
use nom::error;
use nom::error::ErrorKind::{Alpha, Digit};
use nom::Err::Error;
use crate::ParseError;
use super::*;
#[test]
fn test_binop() {
assert_eq!(
binop(b"+"),
Ok((&b""[..], Token::Binary(Operation::Plus)))
);
assert_eq!(
binop(b"-"),
Ok((&b""[..], Token::Binary(Operation::Minus)))
);
assert_eq!(
binop(b"*"),
Ok((&b""[..], Token::Binary(Operation::Times)))
);
assert_eq!(
binop(b"/"),
Ok((&b""[..], Token::Binary(Operation::Div)))
);
}
#[test]
fn test_number() {
assert_eq!(
number(b"32143"),
Ok((&b""[..], Token::Number(32143f64)))
);
assert_eq!(
number(b"2."),
Ok((&b""[..], Token::Number(2.0f64)))
);
assert_eq!(
number(b"32143.25"),
Ok((&b""[..], Token::Number(32143.25f64)))
);
assert_eq!(
number(b"0.125e9"),
Ok((&b""[..], Token::Number(0.125e9f64)))
);
assert_eq!(
number(b"20.5E-3"),
Ok((&b""[..], Token::Number(20.5E-3f64)))
);
assert_eq!(
number(b"123423e+50"),
Ok((&b""[..], Token::Number(123423e+50f64)))
);
assert_eq!(
number(b""),
Err(Error(error::Error {
input: &b""[..],
code: Digit
}))
);
assert_eq!(
number(b".2"),
Err(Error(error::Error {
input: &b".2"[..],
code: Digit
}))
);
assert_eq!(
number(b"+"),
Err(Error(error::Error {
input: &b"+"[..],
code: Digit
}))
);
assert_eq!(
number(b"e"),
Err(Error(error::Error {
input: &b"e"[..],
code: Digit
}))
);
assert_eq!(
number(b"1E"),
Err(Error(error::Error {
input: &b""[..],
code: Digit
}))
);
assert_eq!(
number(b"1e+"),
Err(Error(error::Error {
input: &b"+"[..],
code: Digit
}))
);
}
#[test]
fn test_var() {
for &s in ["abc", "U0", "_034", "a_be45EA", "aAzZ_"].iter() {
assert_eq!(
var(s.as_bytes()),
Ok((&b""[..], Token::Var(s.into())))
);
}
for &s in ["\'a\'", "\"U0\"", "\"_034\"", "'*'", "\"+\""].iter() {
assert_eq!(
var(s.as_bytes()),
Ok((&b""[..], tokenize(s).unwrap()[0].clone()))
);
}
assert_eq!(
var(b""),
Err(Error(error::Error {
input: &b""[..],
code: Alpha
}))
);
assert_eq!(
var(b"0"),
Err(Error(error::Error {
input: &b"0"[..],
code: Alpha
}))
);
}
#[test]
fn test_func() {
for &s in ["abc(", "u0(", "_034 (", "A_be45EA ("].iter() {
assert_eq!(
func(s.as_bytes()),
Ok((&b""[..], Token::Func(s[0..s.len() - 1].trim().into(), None)))
);
}
assert_eq!(
func(b""),
Err(Error(error::Error {
input: &b""[..],
code: Alpha
}))
);
assert_eq!(
func(b"("),
Err(Error(error::Error {
input: &b"("[..],
code: Alpha
}))
);
assert_eq!(
func(b"0("),
Err(Error(error::Error {
input: &b"0("[..],
code: Alpha
}))
);
}
#[test]
fn test_tokenize() {
use super::Operation::*;
use super::Token::*;
assert_eq!(tokenize("a"), Ok(vec![Var("a".into())]));
assert_eq!(
tokenize("2 +(3--2) "),
Ok(vec![
Number(2f64),
Binary(Plus),
LParen,
Number(3f64),
Binary(Minus),
Unary(Minus),
Number(2f64),
RParen,
])
);
assert_eq!(
tokenize("-2^ ab0 *12 - C_0"),
Ok(vec![
Unary(Minus),
Number(2f64),
Binary(Pow),
Var("ab0".into()),
Binary(Times),
Number(12f64),
Binary(Minus),
Var("C_0".into()),
])
);
assert_eq!(
tokenize("-sin(pi * 3)^ cos(2) / Func2(x, f(y), z) * _buildIN(y)"),
Ok(vec![
Unary(Minus),
Func("sin".into(), None),
Var("pi".into()),
Binary(Times),
Number(3f64),
RParen,
Binary(Pow),
Func("cos".into(), None),
Number(2f64),
RParen,
Binary(Div),
Func("Func2".into(), None),
Var("x".into()),
Comma,
Func("f".into(), None),
Var("y".into()),
RParen,
Comma,
Var("z".into()),
RParen,
Binary(Times),
Func("_buildIN".into(), None),
Var("y".into()),
RParen,
])
);
assert_eq!(
tokenize("2 % 3"),
Ok(vec![Number(2f64), Binary(Rem), Number(3f64)])
);
assert_eq!(
tokenize("1 + 3! + 1"),
Ok(vec![
Number(1f64),
Binary(Plus),
Number(3f64),
Unary(Fact),
Binary(Plus),
Number(1f64),
])
);
assert_eq!(tokenize("!3"), Err(ParseError::UnexpectedToken(2)));
assert_eq!(tokenize("()"), Err(ParseError::UnexpectedToken(1)));
assert_eq!(tokenize(""), Err(ParseError::MissingArgument));
assert_eq!(tokenize("2)"), Err(ParseError::UnexpectedToken(1)));
assert_eq!(tokenize("2^"), Err(ParseError::MissingArgument));
assert_eq!(tokenize("(((2)"), Err(ParseError::MissingRParen(2)));
assert_eq!(tokenize("f(2,)"), Err(ParseError::UnexpectedToken(1)));
assert_eq!(tokenize("f(,2)"), Err(ParseError::UnexpectedToken(3)));
}
#[test]
fn test_func_with_no_para() {
assert_eq!(
tokenize("f()"),
Ok(vec![Token::Func("f".to_string(), Some(0))])
);
assert_eq!(
tokenize("f( )"),
Ok(vec![Token::Func("f".to_string(), Some(0))])
);
assert!(tokenize("f(f2(1), f3())").is_ok());
assert!(tokenize("f(f2(1), f3(), a)").is_ok());
assert!(tokenize("f(a, b, f2(), f3(), c)").is_ok());
assert!(tokenize("-sin(pi * 3)^ cos(2) / Func2(x, f(), z) * _buildIN()").is_ok());
}
#[test]
fn test_show_latex() {
//let test_token = tokenize("x1^2-10*x1+x2^2+8<=5*2").unwrap();
//let test_token = tokenize("max((5*1)*x1+3*x2+2*x3+(10-3)*x4+4*x5)").unwrap();
//let test_token = tokenize("1*3*x2+sin(8-2)*x3 - cos(pi)< 7").unwrap();
//let test_token = tokenize("x1%5+3/3*x2+min(2,5)*x3*2e19 && 1").unwrap();
//let test_token = tokenize("2!").unwrap();
let test_token = tokenize("~x1").unwrap();
println!("{:?}", test_token);
for x in test_token {
println!("{}", x);
}
}
#[test]
fn test_tensor() {
assert_eq!(
tokenize("[3]"),
Ok(vec![Token::Tensor(None), Token::Number(3.), Token::RBracket])
);
assert!(tokenize("[[1,2],[3,4]]").is_ok());
}
}
}
\ No newline at end of file
// flowing should as same as in sparrowzz
use ndarray::{arr1, Array1, Array2, Axis, IxDyn, SliceInfo, SliceInfoElem};
use num_traits::ToPrimitive;
use crate::{FuncEvalError, MyCx, MyF};
#[cfg(feature = "enable_ndarray_blas")]
use ndarray::{array};
#[cfg(feature = "enable_ndarray_blas")]
use ndarray_linalg::*;
// #[cfg(feature = "enable_ndarray_blas")]
use num_complex::Complex64;
pub trait TsLinalgFn {
......@@ -369,12 +365,12 @@ impl TsfnBasic {
MyF::Tensor(t) => {
if t.ndim() > 1 {
if t.shape().len() == 2 && (t.shape()[0] == 1 || t.shape()[1] == 1) {
Ok(MyF::Tensor(Array2::from_diag(&arr1(t.clone().into_raw_vec().as_slice())).into_dyn()))
Ok(MyF::Tensor(Array2::from_diag(&arr1(t.clone().into_raw_vec_and_offset().0.as_slice())).into_dyn()))
} else {
Ok(MyF::Tensor(t.diag().into_dyn().to_owned()))
}
} else {
Ok(MyF::Tensor(Array2::from_diag(&arr1(t.clone().into_raw_vec().as_slice())).into_dyn()))
Ok(MyF::Tensor(Array2::from_diag(&arr1(t.clone().into_raw_vec_and_offset().0.as_slice())).into_dyn()))
}
}
}
......@@ -386,12 +382,12 @@ impl TsfnBasic {
MyCx::Tensor(t) => {
if t.ndim() > 1 {
if t.shape().len() == 2 && t.shape()[1] == 1 {
Ok(MyCx::Tensor(Array2::from_diag(&arr1(t.clone().into_raw_vec().as_slice())).into_dyn()))
Ok(MyCx::Tensor(Array2::from_diag(&arr1(t.clone().into_raw_vec_and_offset().0.as_slice())).into_dyn()))
} else {
Ok(MyCx::Tensor(t.diag().into_dyn().to_owned()))
}
} else {
Ok(MyCx::Tensor(Array2::from_diag(&arr1(t.clone().into_raw_vec().as_slice())).into_dyn()))
Ok(MyCx::Tensor(Array2::from_diag(&arr1(t.clone().into_raw_vec_and_offset().0.as_slice())).into_dyn()))
}
}
}
......@@ -560,9 +556,9 @@ impl TsfnBasic {
}
if eq_size {
let mut matrix = Array2::zeros([m.to_usize().unwrap(), n.to_usize().unwrap()]);
let i_vec = i.to_owned().into_raw_vec();
let j_vec = j.to_owned().into_raw_vec();
let v_vec = v.to_owned().into_raw_vec();
let i_vec = i.to_owned().into_raw_vec_and_offset().0;
let j_vec = j.to_owned().into_raw_vec_and_offset().0;
let v_vec = v.to_owned().into_raw_vec_and_offset().0;
for k in 0..i.len() {
if i_vec[k] >= *m {
return Err(FuncEvalError::NumberArgs(0))
......@@ -609,59 +605,4 @@ impl TsfnBasic {
}
}
}
impl TsLinalgFn for TsfnBasic {
#[cfg(feature = "enable_ndarray_blas")]
fn ts_eig(args: &[MyCx]) -> Result<MyCx, FuncEvalError> {
match &args[0] {
MyCx::F64(f) => Ok(MyCx::Tensor(array![*f, Complex64::new(1., 0.)].into_dyn())),
MyCx::Tensor(t) => {
if t.ndim() == 2 && t.shape()[0] == t.shape()[1] {
match t.clone().into_dimensionality() {
Ok(t2) => {
let (eigs, _) = t2.eig().map_err(|_| FuncEvalError::UnknownFunction)?;
Ok(MyCx::Tensor(eigs.into_dyn()))
},
Err(_) => Err(FuncEvalError::NumberArgs(0)),
}
} else {
Err(FuncEvalError::NumberArgs(0))
}
}
}
}
#[cfg(feature = "enable_ndarray_blas")]
fn ts_trace(args: &[MyF]) -> Result<MyF, FuncEvalError> {
match &args[0] {
MyF::F64(f) => Ok(MyF::F64(*f)),
MyF::Tensor(t) => {
if t.ndim() == 2 && t.shape()[0] == t.shape()[1] {
match t.clone().into_dimensionality() {
Ok(t2) => Ok(MyF::F64(t2.trace().map_err(|_| FuncEvalError::UnknownFunction)?)),
Err(_) => Err(FuncEvalError::NumberArgs(0)),
}
} else {
Err(FuncEvalError::NumberArgs(0))
}
}
}
}
#[cfg(feature = "enable_ndarray_blas")]
fn ts_trace_cx(args: &[MyCx]) -> Result<MyCx, FuncEvalError> {
match &args[0] {
MyCx::F64(f) => Ok(MyCx::F64(*f)),
MyCx::Tensor(t) => {
if t.ndim() == 2 && t.shape()[0] == t.shape()[1] {
match t.clone().into_dimensionality() {
Ok(t2) => Ok(MyCx::F64(t2.trace().map_err(|_| FuncEvalError::UnknownFunction)?)),
Err(_) => Err(FuncEvalError::NumberArgs(0)),
}
} else {
Err(FuncEvalError::NumberArgs(0))
}
}
}
}
}
// above should as same as in sparrowzz
\ No newline at end of file
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论