Generator for SVE intrinsics.
Co-authored-by: Jamie Cunliffe <Jamie.Cunliffe@arm.com> Co-authored-by: Jacob Bramley <jacob.bramley@arm.com> Co-authored-by: Luca Vizzarro <Luca.Vizzarro@arm.com> Co-authored-by: Adam Gemmell <adam.gemmell@arm.com>
This commit is contained in:
parent
9e24b307df
commit
03e4f2636e
14 changed files with 6197 additions and 0 deletions
|
|
@ -3,6 +3,18 @@ resolver = "1"
|
|||
members = [
|
||||
"crates/*",
|
||||
"examples"
|
||||
"crates/stdarch-verify",
|
||||
"crates/core_arch",
|
||||
"crates/std_detect",
|
||||
"crates/stdarch-gen-arm",
|
||||
"crates/stdarch-gen-loongarch",
|
||||
"crates/stdarch-gen",
|
||||
"crates/stdarch-gen2",
|
||||
"crates/intrinsic-test",
|
||||
"examples/"
|
||||
]
|
||||
exclude = [
|
||||
"crates/wasm-assert-instr-tests"
|
||||
]
|
||||
|
||||
[profile.release]
|
||||
|
|
|
|||
22
library/stdarch/crates/stdarch-gen2/Cargo.toml
Normal file
22
library/stdarch/crates/stdarch-gen2/Cargo.toml
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
[package]
|
||||
name = "stdarch-gen2"
|
||||
version = "0.1.0"
|
||||
authors = ["Luca Vizzarro <luca.vizzarro@arm.com>",
|
||||
"Jamie Cunliffe <Jamie.Cunliffe@arm.com>",
|
||||
"Adam Gemmell <Adam.Gemmell@arm.com",
|
||||
"Jacob Bramley <jacob.bramley@arm.com>"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
itertools = "0.10"
|
||||
lazy_static = "1.4.0"
|
||||
proc-macro2 = "1.0"
|
||||
quote = "1.0"
|
||||
regex = "1.5"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_with = "1.14"
|
||||
serde_yaml = "0.8"
|
||||
walkdir = "2.3.2"
|
||||
372
library/stdarch/crates/stdarch-gen2/src/assert_instr.rs
Normal file
372
library/stdarch/crates/stdarch-gen2/src/assert_instr.rs
Normal file
|
|
@ -0,0 +1,372 @@
|
|||
use proc_macro2::TokenStream;
|
||||
use quote::{format_ident, quote, ToTokens, TokenStreamExt};
|
||||
use serde::de::{self, MapAccess, Visitor};
|
||||
use serde::{ser::SerializeSeq, Deserialize, Deserializer, Serialize};
|
||||
use std::fmt;
|
||||
|
||||
use crate::{
|
||||
context::{self, Context},
|
||||
typekinds::{BaseType, BaseTypeKind},
|
||||
wildstring::WildString,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum InstructionAssertion {
|
||||
Basic(WildString),
|
||||
WithArgs(WildString, WildString),
|
||||
}
|
||||
|
||||
impl InstructionAssertion {
|
||||
fn build(&mut self, ctx: &Context) -> context::Result {
|
||||
match self {
|
||||
InstructionAssertion::Basic(ws) => ws.build_acle(ctx.local),
|
||||
InstructionAssertion::WithArgs(ws, args_ws) => [ws, args_ws]
|
||||
.into_iter()
|
||||
.try_for_each(|ws| ws.build_acle(ctx.local)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ToTokens for InstructionAssertion {
|
||||
fn to_tokens(&self, tokens: &mut TokenStream) {
|
||||
let instr = format_ident!(
|
||||
"{}",
|
||||
match self {
|
||||
Self::Basic(instr) => instr,
|
||||
Self::WithArgs(instr, _) => instr,
|
||||
}
|
||||
.to_string()
|
||||
);
|
||||
tokens.append_all(quote! { #instr });
|
||||
|
||||
if let Self::WithArgs(_, args) = self {
|
||||
let ex: TokenStream = args
|
||||
.to_string()
|
||||
.parse()
|
||||
.expect("invalid instruction assertion arguments expression given");
|
||||
tokens.append_all(quote! {, #ex})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Asserts that the given instruction is present for the intrinsic of the associated type bitsize.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(remote = "Self")]
|
||||
pub struct InstructionAssertionMethodForBitsize {
|
||||
pub default: InstructionAssertion,
|
||||
pub byte: Option<InstructionAssertion>,
|
||||
pub halfword: Option<InstructionAssertion>,
|
||||
pub word: Option<InstructionAssertion>,
|
||||
pub doubleword: Option<InstructionAssertion>,
|
||||
}
|
||||
|
||||
impl InstructionAssertionMethodForBitsize {
|
||||
fn build(&mut self, ctx: &Context) -> context::Result {
|
||||
if let Some(ref mut byte) = self.byte {
|
||||
byte.build(ctx)?
|
||||
}
|
||||
if let Some(ref mut halfword) = self.halfword {
|
||||
halfword.build(ctx)?
|
||||
}
|
||||
if let Some(ref mut word) = self.word {
|
||||
word.build(ctx)?
|
||||
}
|
||||
if let Some(ref mut doubleword) = self.doubleword {
|
||||
doubleword.build(ctx)?
|
||||
}
|
||||
self.default.build(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for InstructionAssertionMethodForBitsize {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
match self {
|
||||
InstructionAssertionMethodForBitsize {
|
||||
default: InstructionAssertion::Basic(instr),
|
||||
byte: None,
|
||||
halfword: None,
|
||||
word: None,
|
||||
doubleword: None,
|
||||
} => serializer.serialize_str(&instr.to_string()),
|
||||
InstructionAssertionMethodForBitsize {
|
||||
default: InstructionAssertion::WithArgs(instr, args),
|
||||
byte: None,
|
||||
halfword: None,
|
||||
word: None,
|
||||
doubleword: None,
|
||||
} => {
|
||||
let mut seq = serializer.serialize_seq(Some(2))?;
|
||||
seq.serialize_element(&instr.to_string())?;
|
||||
seq.serialize_element(&args.to_string())?;
|
||||
seq.end()
|
||||
}
|
||||
_ => InstructionAssertionMethodForBitsize::serialize(self, serializer),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for InstructionAssertionMethodForBitsize {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
struct IAMVisitor;
|
||||
|
||||
impl<'de> Visitor<'de> for IAMVisitor {
|
||||
type Value = InstructionAssertionMethodForBitsize;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
formatter.write_str("array, string or map")
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, value: &str) -> Result<InstructionAssertionMethodForBitsize, E>
|
||||
where
|
||||
E: de::Error,
|
||||
{
|
||||
Ok(InstructionAssertionMethodForBitsize {
|
||||
default: InstructionAssertion::Basic(value.parse().map_err(E::custom)?),
|
||||
byte: None,
|
||||
halfword: None,
|
||||
word: None,
|
||||
doubleword: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
|
||||
where
|
||||
A: de::SeqAccess<'de>,
|
||||
{
|
||||
use serde::de::Error;
|
||||
let make_err =
|
||||
|| Error::custom("invalid number of arguments passed to assert_instruction");
|
||||
let instruction = seq.next_element()?.ok_or_else(make_err)?;
|
||||
let args = seq.next_element()?.ok_or_else(make_err)?;
|
||||
|
||||
if let Some(true) = seq.size_hint().map(|len| len > 0) {
|
||||
Err(make_err())
|
||||
} else {
|
||||
Ok(InstructionAssertionMethodForBitsize {
|
||||
default: InstructionAssertion::WithArgs(instruction, args),
|
||||
byte: None,
|
||||
halfword: None,
|
||||
word: None,
|
||||
doubleword: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_map<M>(self, map: M) -> Result<InstructionAssertionMethodForBitsize, M::Error>
|
||||
where
|
||||
M: MapAccess<'de>,
|
||||
{
|
||||
InstructionAssertionMethodForBitsize::deserialize(
|
||||
de::value::MapAccessDeserializer::new(map),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_any(IAMVisitor)
|
||||
}
|
||||
}
|
||||
|
||||
/// Asserts that the given instruction is present for the intrinsic of the associated type.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(remote = "Self")]
|
||||
pub struct InstructionAssertionMethod {
|
||||
/// Instruction for integer intrinsics
|
||||
pub default: InstructionAssertionMethodForBitsize,
|
||||
/// Instruction for floating-point intrinsics (optional)
|
||||
#[serde(default)]
|
||||
pub float: Option<InstructionAssertionMethodForBitsize>,
|
||||
/// Instruction for unsigned integer intrinsics (optional)
|
||||
#[serde(default)]
|
||||
pub unsigned: Option<InstructionAssertionMethodForBitsize>,
|
||||
}
|
||||
|
||||
impl InstructionAssertionMethod {
|
||||
pub(crate) fn build(&mut self, ctx: &Context) -> context::Result {
|
||||
if let Some(ref mut float) = self.float {
|
||||
float.build(ctx)?
|
||||
}
|
||||
if let Some(ref mut unsigned) = self.unsigned {
|
||||
unsigned.build(ctx)?
|
||||
}
|
||||
self.default.build(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for InstructionAssertionMethod {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
match self {
|
||||
InstructionAssertionMethod {
|
||||
default:
|
||||
InstructionAssertionMethodForBitsize {
|
||||
default: InstructionAssertion::Basic(instr),
|
||||
byte: None,
|
||||
halfword: None,
|
||||
word: None,
|
||||
doubleword: None,
|
||||
},
|
||||
float: None,
|
||||
unsigned: None,
|
||||
} => serializer.serialize_str(&instr.to_string()),
|
||||
InstructionAssertionMethod {
|
||||
default:
|
||||
InstructionAssertionMethodForBitsize {
|
||||
default: InstructionAssertion::WithArgs(instr, args),
|
||||
byte: None,
|
||||
halfword: None,
|
||||
word: None,
|
||||
doubleword: None,
|
||||
},
|
||||
float: None,
|
||||
unsigned: None,
|
||||
} => {
|
||||
let mut seq = serializer.serialize_seq(Some(2))?;
|
||||
seq.serialize_element(&instr.to_string())?;
|
||||
seq.serialize_element(&args.to_string())?;
|
||||
seq.end()
|
||||
}
|
||||
_ => InstructionAssertionMethod::serialize(self, serializer),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for InstructionAssertionMethod {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
struct IAMVisitor;
|
||||
|
||||
impl<'de> Visitor<'de> for IAMVisitor {
|
||||
type Value = InstructionAssertionMethod;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
formatter.write_str("array, string or map")
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, value: &str) -> Result<InstructionAssertionMethod, E>
|
||||
where
|
||||
E: de::Error,
|
||||
{
|
||||
Ok(InstructionAssertionMethod {
|
||||
default: InstructionAssertionMethodForBitsize {
|
||||
default: InstructionAssertion::Basic(value.parse().map_err(E::custom)?),
|
||||
byte: None,
|
||||
halfword: None,
|
||||
word: None,
|
||||
doubleword: None,
|
||||
},
|
||||
float: None,
|
||||
unsigned: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
|
||||
where
|
||||
A: de::SeqAccess<'de>,
|
||||
{
|
||||
use serde::de::Error;
|
||||
let make_err =
|
||||
|| Error::custom("invalid number of arguments passed to assert_instruction");
|
||||
let instruction = seq.next_element()?.ok_or_else(make_err)?;
|
||||
let args = seq.next_element()?.ok_or_else(make_err)?;
|
||||
|
||||
if let Some(true) = seq.size_hint().map(|len| len > 0) {
|
||||
Err(make_err())
|
||||
} else {
|
||||
Ok(InstructionAssertionMethod {
|
||||
default: InstructionAssertionMethodForBitsize {
|
||||
default: InstructionAssertion::WithArgs(instruction, args),
|
||||
byte: None,
|
||||
halfword: None,
|
||||
word: None,
|
||||
doubleword: None,
|
||||
},
|
||||
float: None,
|
||||
unsigned: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_map<M>(self, map: M) -> Result<InstructionAssertionMethod, M::Error>
|
||||
where
|
||||
M: MapAccess<'de>,
|
||||
{
|
||||
InstructionAssertionMethod::deserialize(de::value::MapAccessDeserializer::new(map))
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_any(IAMVisitor)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct InstructionAssertionsForBaseType<'a>(
|
||||
pub &'a Vec<InstructionAssertionMethod>,
|
||||
pub &'a Option<&'a BaseType>,
|
||||
);
|
||||
|
||||
impl<'a> ToTokens for InstructionAssertionsForBaseType<'a> {
|
||||
fn to_tokens(&self, tokens: &mut TokenStream) {
|
||||
self.0.iter().for_each(
|
||||
|InstructionAssertionMethod {
|
||||
default,
|
||||
float,
|
||||
unsigned,
|
||||
}| {
|
||||
let kind = self.1.map(|ty| ty.kind());
|
||||
let instruction = match (kind, float, unsigned) {
|
||||
(None, float, unsigned) if float.is_some() || unsigned.is_some() => {
|
||||
unreachable!(
|
||||
"cannot determine the base type kind for instruction assertion: {self:#?}")
|
||||
}
|
||||
(Some(BaseTypeKind::Float), Some(float), _) => float,
|
||||
(Some(BaseTypeKind::UInt), _, Some(unsigned)) => unsigned,
|
||||
_ => default,
|
||||
};
|
||||
|
||||
let bitsize = self.1.and_then(|ty| ty.get_size().ok());
|
||||
let instruction = match (bitsize, instruction) {
|
||||
(
|
||||
Some(8),
|
||||
InstructionAssertionMethodForBitsize {
|
||||
byte: Some(byte), ..
|
||||
},
|
||||
) => byte,
|
||||
(
|
||||
Some(16),
|
||||
InstructionAssertionMethodForBitsize {
|
||||
halfword: Some(halfword),
|
||||
..
|
||||
},
|
||||
) => halfword,
|
||||
(
|
||||
Some(32),
|
||||
InstructionAssertionMethodForBitsize {
|
||||
word: Some(word), ..
|
||||
},
|
||||
) => word,
|
||||
(
|
||||
Some(64),
|
||||
InstructionAssertionMethodForBitsize {
|
||||
doubleword: Some(doubleword),
|
||||
..
|
||||
},
|
||||
) => doubleword,
|
||||
(_, InstructionAssertionMethodForBitsize { default, .. }) => default,
|
||||
};
|
||||
|
||||
tokens.append_all(quote! { #[cfg_attr(test, assert_instr(#instruction))]})
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
249
library/stdarch/crates/stdarch-gen2/src/context.rs
Normal file
249
library/stdarch/crates/stdarch-gen2/src/context.rs
Normal file
|
|
@ -0,0 +1,249 @@
|
|||
use itertools::Itertools;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::{
|
||||
expression::Expression,
|
||||
input::{InputSet, InputType},
|
||||
intrinsic::{Constraint, Intrinsic, Signature},
|
||||
matching::SizeMatchable,
|
||||
predicate_forms::PredicateForm,
|
||||
typekinds::{ToRepr, TypeKind},
|
||||
wildcards::Wildcard,
|
||||
wildstring::WildString,
|
||||
};
|
||||
|
||||
/// Maximum SVE vector size
|
||||
const SVE_VECTOR_MAX_SIZE: u32 = 2048;
|
||||
/// Vector register size
|
||||
const VECTOR_REG_SIZE: u32 = 128;
|
||||
|
||||
/// Generator result
|
||||
pub type Result<T = ()> = std::result::Result<T, String>;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ArchitectureSettings {
|
||||
#[serde(alias = "arch")]
|
||||
pub arch_name: String,
|
||||
pub target_feature: Vec<String>,
|
||||
#[serde(alias = "llvm_prefix")]
|
||||
pub llvm_link_prefix: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GlobalContext {
|
||||
pub arch_cfgs: Vec<ArchitectureSettings>,
|
||||
#[serde(default)]
|
||||
pub uses_neon_types: bool,
|
||||
}
|
||||
|
||||
/// Context of an intrinsic group
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct GroupContext {
|
||||
/// LLVM links to target input sets
|
||||
pub links: HashMap<String, InputSet>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum VariableType {
|
||||
Argument,
|
||||
Internal,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LocalContext {
|
||||
pub signature: Signature,
|
||||
|
||||
pub input: InputSet,
|
||||
|
||||
pub substitutions: HashMap<Wildcard, String>,
|
||||
pub variables: HashMap<String, (TypeKind, VariableType)>,
|
||||
}
|
||||
|
||||
impl LocalContext {
|
||||
pub fn new(input: InputSet, original: &Intrinsic) -> LocalContext {
|
||||
LocalContext {
|
||||
signature: original.signature.clone(),
|
||||
input,
|
||||
substitutions: HashMap::new(),
|
||||
variables: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn provide_type_wildcard(&self, wildcard: &Wildcard) -> Result<TypeKind> {
|
||||
let err = || format!("wildcard {{{wildcard}}} not found");
|
||||
|
||||
let make_neon = |tuple_size| move |ty| TypeKind::make_vector(ty, false, tuple_size);
|
||||
let make_sve = |tuple_size| move |ty| TypeKind::make_vector(ty, true, tuple_size);
|
||||
|
||||
match wildcard {
|
||||
Wildcard::Type(idx) => self.input.typekind(*idx).ok_or_else(err),
|
||||
Wildcard::NEONType(idx, tuple_size) => self
|
||||
.input
|
||||
.typekind(*idx)
|
||||
.ok_or_else(err)
|
||||
.and_then(make_neon(*tuple_size)),
|
||||
Wildcard::SVEType(idx, tuple_size) => self
|
||||
.input
|
||||
.typekind(*idx)
|
||||
.ok_or_else(err)
|
||||
.and_then(make_sve(*tuple_size)),
|
||||
Wildcard::Predicate(idx) => self.input.typekind(*idx).map_or_else(
|
||||
|| {
|
||||
if idx.is_none() && self.input.types_len() == 1 {
|
||||
Err(err())
|
||||
} else {
|
||||
Err(format!(
|
||||
"there is no type at index {} to infer the predicate from",
|
||||
idx.unwrap_or(0)
|
||||
))
|
||||
}
|
||||
},
|
||||
|ref ty| TypeKind::make_predicate_from(ty),
|
||||
),
|
||||
Wildcard::MaxPredicate => self
|
||||
.input
|
||||
.iter()
|
||||
.filter_map(|arg| arg.typekind())
|
||||
.max_by(|x, y| {
|
||||
x.base_type()
|
||||
.and_then(|bt| bt.get_size().ok())
|
||||
.unwrap_or(0)
|
||||
.cmp(&y.base_type().and_then(|bt| bt.get_size().ok()).unwrap_or(0))
|
||||
})
|
||||
.map_or_else(
|
||||
|| Err("there are no types available to infer the predicate from".to_string()),
|
||||
TypeKind::make_predicate_from,
|
||||
),
|
||||
Wildcard::Scale(w, as_ty) => {
|
||||
let mut ty = self.provide_type_wildcard(w)?;
|
||||
if let Some(vty) = ty.vector_mut() {
|
||||
let base_ty = if let Some(w) = as_ty.wildcard() {
|
||||
*self.provide_type_wildcard(w)?.base_type().unwrap()
|
||||
} else {
|
||||
*as_ty.base_type().unwrap()
|
||||
};
|
||||
vty.cast_base_type_as(base_ty)
|
||||
}
|
||||
Ok(ty)
|
||||
}
|
||||
_ => Err(err()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn provide_substitution_wildcard(&self, wildcard: &Wildcard) -> Result<String> {
|
||||
let err = || Err(format!("wildcard {{{wildcard}}} not found"));
|
||||
|
||||
match wildcard {
|
||||
Wildcard::SizeLiteral(idx) => self.input.typekind(*idx)
|
||||
.map_or_else(err, |ty| Ok(ty.size_literal())),
|
||||
Wildcard::Size(idx) => self.input.typekind(*idx)
|
||||
.map_or_else(err, |ty| Ok(ty.size())),
|
||||
Wildcard::SizeMinusOne(idx) => self.input.typekind(*idx)
|
||||
.map_or_else(err, |ty| Ok((ty.size().parse::<i32>().unwrap()-1).to_string())),
|
||||
Wildcard::SizeInBytesLog2(idx) => self.input.typekind(*idx)
|
||||
.map_or_else(err, |ty| Ok(ty.size_in_bytes_log2())),
|
||||
Wildcard::NVariant if self.substitutions.get(wildcard).is_none() => Ok(String::new()),
|
||||
Wildcard::TypeKind(idx, opts) => {
|
||||
self.input.typekind(*idx)
|
||||
.map_or_else(err, |ty| {
|
||||
let literal = if let Some(opts) = opts {
|
||||
opts.contains(ty.base_type().map(|bt| *bt.kind()).ok_or_else(|| {
|
||||
format!("cannot retrieve a type literal out of {ty}")
|
||||
})?)
|
||||
.then(|| ty.type_kind())
|
||||
.unwrap_or_default()
|
||||
} else {
|
||||
ty.type_kind()
|
||||
};
|
||||
Ok(literal)
|
||||
})
|
||||
}
|
||||
Wildcard::PredicateForms(_) => self
|
||||
.input
|
||||
.iter()
|
||||
.find_map(|arg| {
|
||||
if let InputType::PredicateForm(pf) = arg {
|
||||
Some(pf.get_suffix().to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.ok_or_else(|| unreachable!("attempting to render a predicate form wildcard, but no predicate form was compiled for it")),
|
||||
_ => self
|
||||
.substitutions
|
||||
.get(wildcard)
|
||||
.map_or_else(err, |s| Ok(s.clone())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn make_assertion_from_constraint(&self, constraint: &Constraint) -> Result<Expression> {
|
||||
match constraint {
|
||||
Constraint::AnyI32 {
|
||||
variable,
|
||||
any_values,
|
||||
} => {
|
||||
let where_ex = any_values
|
||||
.iter()
|
||||
.map(|value| format!("{variable} == {value}"))
|
||||
.join(" || ");
|
||||
Ok(Expression::MacroCall("static_assert".to_string(), where_ex))
|
||||
}
|
||||
Constraint::RangeI32 {
|
||||
variable,
|
||||
range: SizeMatchable::Matched(range),
|
||||
} => Ok(Expression::MacroCall(
|
||||
"static_assert_range".to_string(),
|
||||
format!(
|
||||
"{variable}, {min}, {max}",
|
||||
min = range.start(),
|
||||
max = range.end()
|
||||
),
|
||||
)),
|
||||
Constraint::SVEMaxElems {
|
||||
variable,
|
||||
sve_max_elems_type: ty,
|
||||
}
|
||||
| Constraint::VecMaxElems {
|
||||
variable,
|
||||
vec_max_elems_type: ty,
|
||||
} => {
|
||||
if !self.input.is_empty() {
|
||||
let higher_limit = match constraint {
|
||||
Constraint::SVEMaxElems { .. } => SVE_VECTOR_MAX_SIZE,
|
||||
Constraint::VecMaxElems { .. } => VECTOR_REG_SIZE,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let max = ty.base_type()
|
||||
.map(|ty| ty.get_size())
|
||||
.transpose()?
|
||||
.map_or_else(
|
||||
|| Err(format!("can't make an assertion out of constraint {self:?}: no valid type is present")),
|
||||
|bitsize| Ok(higher_limit / bitsize - 1))?;
|
||||
Ok(Expression::MacroCall(
|
||||
"static_assert_range".to_string(),
|
||||
format!("{variable}, 0, {max}"),
|
||||
))
|
||||
} else {
|
||||
Err(format!("can't make an assertion out of constraint {self:?}: no types are being used"))
|
||||
}
|
||||
}
|
||||
_ => unreachable!("constraints were not built successfully!"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn predicate_form(&self) -> Option<&PredicateForm> {
|
||||
self.input.iter().find_map(|arg| arg.predicate_form())
|
||||
}
|
||||
|
||||
pub fn n_variant_op(&self) -> Option<&WildString> {
|
||||
self.input.iter().find_map(|arg| arg.n_variant_op())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Context<'ctx> {
|
||||
pub local: &'ctx mut LocalContext,
|
||||
pub group: &'ctx mut GroupContext,
|
||||
pub global: &'ctx GlobalContext,
|
||||
}
|
||||
546
library/stdarch/crates/stdarch-gen2/src/expression.rs
Normal file
546
library/stdarch/crates/stdarch-gen2/src/expression.rs
Normal file
|
|
@ -0,0 +1,546 @@
|
|||
use itertools::Itertools;
|
||||
use lazy_static::lazy_static;
|
||||
use proc_macro2::{Literal, TokenStream};
|
||||
use quote::{format_ident, quote, ToTokens, TokenStreamExt};
|
||||
use regex::Regex;
|
||||
use serde::de::{self, MapAccess, Visitor};
|
||||
use serde::{Deserialize, Deserializer, Serialize};
|
||||
use std::fmt;
|
||||
use std::str::FromStr;
|
||||
|
||||
use crate::intrinsic::Intrinsic;
|
||||
use crate::{
|
||||
context::{self, Context, VariableType},
|
||||
intrinsic::{Argument, LLVMLink, StaticDefinition},
|
||||
matching::{MatchKindValues, MatchSizeValues},
|
||||
typekinds::{BaseType, BaseTypeKind, TypeKind},
|
||||
wildcards::Wildcard,
|
||||
wildstring::WildString,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
pub enum IdentifierType {
|
||||
Variable,
|
||||
Symbol,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum LetVariant {
|
||||
Basic(WildString, Box<Expression>),
|
||||
WithType(WildString, TypeKind, Box<Expression>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FnCall(
|
||||
/// Function pointer
|
||||
pub Box<Expression>,
|
||||
/// Function arguments
|
||||
pub Vec<Expression>,
|
||||
/// Function turbofish arguments
|
||||
#[serde(default)]
|
||||
pub Vec<Expression>,
|
||||
);
|
||||
|
||||
impl FnCall {
|
||||
pub fn new_expression(fn_ptr: Expression, arguments: Vec<Expression>) -> Expression {
|
||||
FnCall(Box::new(fn_ptr), arguments, Vec::new()).into()
|
||||
}
|
||||
|
||||
pub fn is_llvm_link_call(&self, llvm_link_name: &String) -> bool {
|
||||
if let Expression::Identifier(fn_name, IdentifierType::Symbol) = self.0.as_ref() {
|
||||
&fn_name.to_string() == llvm_link_name
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub fn pre_build(&mut self, ctx: &mut Context) -> context::Result {
|
||||
self.0.pre_build(ctx)?;
|
||||
self.1
|
||||
.iter_mut()
|
||||
.chain(self.2.iter_mut())
|
||||
.try_for_each(|ex| ex.pre_build(ctx))
|
||||
}
|
||||
|
||||
pub fn build(&mut self, intrinsic: &Intrinsic, ctx: &mut Context) -> context::Result {
|
||||
self.0.build(intrinsic, ctx)?;
|
||||
self.1
|
||||
.iter_mut()
|
||||
.chain(self.2.iter_mut())
|
||||
.try_for_each(|ex| ex.build(intrinsic, ctx))
|
||||
}
|
||||
}
|
||||
|
||||
impl ToTokens for FnCall {
|
||||
fn to_tokens(&self, tokens: &mut TokenStream) {
|
||||
let FnCall(fn_ptr, arguments, turbofish) = self;
|
||||
|
||||
fn_ptr.to_tokens(tokens);
|
||||
|
||||
if !turbofish.is_empty() {
|
||||
tokens.append_all(quote! {::<#(#turbofish),*>});
|
||||
}
|
||||
|
||||
tokens.append_all(quote! { (#(#arguments),*) })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(remote = "Self", deny_unknown_fields)]
|
||||
pub enum Expression {
|
||||
/// (Re)Defines a variable
|
||||
Let(LetVariant),
|
||||
/// Performs a variable assignment operation
|
||||
Assign(String, Box<Expression>),
|
||||
/// Performs a macro call
|
||||
MacroCall(String, String),
|
||||
/// Performs a function call
|
||||
FnCall(FnCall),
|
||||
/// Performs a method call. The following:
|
||||
/// `MethodCall: ["$object", "to_string", []]`
|
||||
/// is tokenized as:
|
||||
/// `object.to_string()`.
|
||||
MethodCall(Box<Expression>, String, Vec<Expression>),
|
||||
/// Symbol identifier name, prepend with a `$` to treat it as a scope variable
|
||||
/// which engages variable tracking and enables inference.
|
||||
/// E.g. `my_function_name` for a generic symbol or `$my_variable` for
|
||||
/// a variable.
|
||||
Identifier(WildString, IdentifierType),
|
||||
/// Constant signed integer number expression
|
||||
IntConstant(i32),
|
||||
/// Constant floating point number expression
|
||||
FloatConstant(f32),
|
||||
/// Constant boolean expression, either `true` or `false`
|
||||
BoolConstant(bool),
|
||||
/// Array expression
|
||||
Array(Vec<Expression>),
|
||||
|
||||
// complex expressions
|
||||
/// Makes an LLVM link.
|
||||
///
|
||||
/// It stores the link's function name in the wildcard `{llvm_link}`, for use in
|
||||
/// subsequent expressions.
|
||||
LLVMLink(LLVMLink),
|
||||
/// Casts the given expression to the specified (unchecked) type
|
||||
CastAs(Box<Expression>, String),
|
||||
/// Returns the LLVM `undef` symbol
|
||||
SvUndef,
|
||||
/// Multiplication
|
||||
Multiply(Box<Expression>, Box<Expression>),
|
||||
/// Converts the specified constant to the specified type's kind
|
||||
ConvertConst(TypeKind, i32),
|
||||
/// Yields the given type in the Rust representation
|
||||
Type(TypeKind),
|
||||
|
||||
MatchSize(TypeKind, MatchSizeValues<Box<Expression>>),
|
||||
MatchKind(TypeKind, MatchKindValues<Box<Expression>>),
|
||||
}
|
||||
|
||||
impl Expression {
|
||||
pub fn pre_build(&mut self, ctx: &mut Context) -> context::Result {
|
||||
match self {
|
||||
Self::FnCall(fn_call) => fn_call.pre_build(ctx),
|
||||
Self::MethodCall(cl_ptr_ex, _, arg_exs) => {
|
||||
cl_ptr_ex.pre_build(ctx)?;
|
||||
arg_exs.iter_mut().try_for_each(|ex| ex.pre_build(ctx))
|
||||
}
|
||||
Self::Let(LetVariant::Basic(_, ex) | LetVariant::WithType(_, _, ex)) => {
|
||||
ex.pre_build(ctx)
|
||||
}
|
||||
Self::CastAs(ex, _) => ex.pre_build(ctx),
|
||||
Self::Multiply(lhs, rhs) => {
|
||||
lhs.pre_build(ctx)?;
|
||||
rhs.pre_build(ctx)
|
||||
}
|
||||
Self::MatchSize(match_ty, values) => {
|
||||
*self = *values.get(match_ty, ctx.local)?.to_owned();
|
||||
self.pre_build(ctx)
|
||||
}
|
||||
Self::MatchKind(match_ty, values) => {
|
||||
*self = *values.get(match_ty, ctx.local)?.to_owned();
|
||||
self.pre_build(ctx)
|
||||
}
|
||||
_ => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn build(&mut self, intrinsic: &Intrinsic, ctx: &mut Context) -> context::Result {
|
||||
match self {
|
||||
Self::LLVMLink(link) => link.build_and_save(ctx),
|
||||
Self::Identifier(identifier, id_type) => {
|
||||
identifier.build_acle(ctx.local)?;
|
||||
|
||||
if let IdentifierType::Variable = id_type {
|
||||
ctx.local
|
||||
.variables
|
||||
.get(&identifier.to_string())
|
||||
.map(|_| ())
|
||||
.ok_or_else(|| format!("invalid variable {identifier} being referenced"))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
Self::FnCall(fn_call) => {
|
||||
fn_call.build(intrinsic, ctx)?;
|
||||
|
||||
if let Some(llvm_link_name) = ctx.local.substitutions.get(&Wildcard::LLVMLink) {
|
||||
if fn_call.is_llvm_link_call(llvm_link_name) {
|
||||
*self = intrinsic
|
||||
.llvm_link()
|
||||
.expect("got LLVMLink wildcard without a LLVM link in `compose`")
|
||||
.apply_conversions_to_call(fn_call.clone(), ctx.local)?
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Self::MethodCall(cl_ptr_ex, _, arg_exs) => {
|
||||
cl_ptr_ex.build(intrinsic, ctx)?;
|
||||
arg_exs
|
||||
.iter_mut()
|
||||
.try_for_each(|ex| ex.build(intrinsic, ctx))
|
||||
}
|
||||
Self::Let(variant) => {
|
||||
let (var_name, ex, ty) = match variant {
|
||||
LetVariant::Basic(var_name, ex) => (var_name, ex, None),
|
||||
LetVariant::WithType(var_name, ty, ex) => {
|
||||
if let Some(w) = ty.wildcard() {
|
||||
ty.populate_wildcard(ctx.local.provide_type_wildcard(w)?)?;
|
||||
}
|
||||
(var_name, ex, Some(ty.to_owned()))
|
||||
}
|
||||
};
|
||||
|
||||
var_name.build_acle(ctx.local)?;
|
||||
ctx.local.variables.insert(
|
||||
var_name.to_string(),
|
||||
(
|
||||
ty.unwrap_or_else(|| TypeKind::Custom("unknown".to_string())),
|
||||
VariableType::Internal,
|
||||
),
|
||||
);
|
||||
ex.build(intrinsic, ctx)
|
||||
}
|
||||
Self::CastAs(ex, _) => ex.build(intrinsic, ctx),
|
||||
Self::Multiply(lhs, rhs) => {
|
||||
lhs.build(intrinsic, ctx)?;
|
||||
rhs.build(intrinsic, ctx)
|
||||
}
|
||||
Self::ConvertConst(ty, num) => {
|
||||
if let Some(w) = ty.wildcard() {
|
||||
*ty = ctx.local.provide_type_wildcard(w)?
|
||||
}
|
||||
|
||||
if let Some(BaseType::Sized(BaseTypeKind::Float, _)) = ty.base() {
|
||||
*self = Expression::FloatConstant(*num as f32)
|
||||
} else {
|
||||
*self = Expression::IntConstant(*num)
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
Self::Type(ty) => {
|
||||
if let Some(w) = ty.wildcard() {
|
||||
*ty = ctx.local.provide_type_wildcard(w)?
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
_ => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
/// True if the expression requires an `unsafe` context in a safe function.
|
||||
///
|
||||
/// The classification is somewhat fuzzy, based on actual usage (e.g. empirical function names)
|
||||
/// rather than a full parse. This is a reasonable approach because mistakes here will usually
|
||||
/// be caught at build time:
|
||||
///
|
||||
/// - Missing an `unsafe` is a build error.
|
||||
/// - An unnecessary `unsafe` is a warning, made into an error by the CI's `-D warnings`.
|
||||
///
|
||||
/// This **panics** if it encounters an expression that shouldn't appear in a safe function at
|
||||
/// all (such as `SvUndef`).
|
||||
pub fn requires_unsafe_wrapper(&self, ctx_fn: &str) -> bool {
|
||||
match self {
|
||||
// The call will need to be unsafe, but the declaration does not.
|
||||
Self::LLVMLink(..) => false,
|
||||
// Identifiers, literals and type names are never unsafe.
|
||||
Self::Identifier(..) => false,
|
||||
Self::IntConstant(..) => false,
|
||||
Self::FloatConstant(..) => false,
|
||||
Self::BoolConstant(..) => false,
|
||||
Self::Type(..) => false,
|
||||
Self::ConvertConst(..) => false,
|
||||
// Nested structures that aren't inherently unsafe, but could contain other expressions
|
||||
// that might be.
|
||||
Self::Assign(_var, exp) => exp.requires_unsafe_wrapper(ctx_fn),
|
||||
Self::Let(LetVariant::Basic(_, exp) | LetVariant::WithType(_, _, exp)) => {
|
||||
exp.requires_unsafe_wrapper(ctx_fn)
|
||||
}
|
||||
Self::Array(exps) => exps.iter().any(|exp| exp.requires_unsafe_wrapper(ctx_fn)),
|
||||
Self::Multiply(lhs, rhs) => {
|
||||
lhs.requires_unsafe_wrapper(ctx_fn) || rhs.requires_unsafe_wrapper(ctx_fn)
|
||||
}
|
||||
Self::CastAs(exp, _ty) => exp.requires_unsafe_wrapper(ctx_fn),
|
||||
// Functions and macros can be unsafe, but can also contain other expressions.
|
||||
Self::FnCall(FnCall(fn_exp, args, turbo_args)) => {
|
||||
let fn_name = fn_exp.to_string();
|
||||
fn_exp.requires_unsafe_wrapper(ctx_fn)
|
||||
|| fn_name.starts_with("_sv")
|
||||
|| fn_name.starts_with("simd_")
|
||||
|| fn_name.ends_with("transmute")
|
||||
|| args.iter().any(|exp| exp.requires_unsafe_wrapper(ctx_fn))
|
||||
|| turbo_args
|
||||
.iter()
|
||||
.any(|exp| exp.requires_unsafe_wrapper(ctx_fn))
|
||||
}
|
||||
Self::MethodCall(exp, fn_name, args) => match fn_name.as_str() {
|
||||
// `as_signed` and `as_unsigned` are unsafe because they're trait methods with
|
||||
// target features to allow use on feature-dependent types (such as SVE vectors).
|
||||
// We can safely wrap them here.
|
||||
"as_signed" => true,
|
||||
"as_unsigned" => true,
|
||||
_ => {
|
||||
exp.requires_unsafe_wrapper(ctx_fn)
|
||||
|| args.iter().any(|exp| exp.requires_unsafe_wrapper(ctx_fn))
|
||||
}
|
||||
},
|
||||
// We only use macros to check const generics (using static assertions).
|
||||
Self::MacroCall(_name, _args) => false,
|
||||
// Materialising uninitialised values is always unsafe, and we avoid it in safe
|
||||
// functions.
|
||||
Self::SvUndef => panic!("Refusing to wrap unsafe SvUndef in safe function '{ctx_fn}'."),
|
||||
// Variants that aren't tokenised. We shouldn't encounter these here.
|
||||
Self::MatchKind(..) => {
|
||||
unimplemented!("The unsafety of {self:?} cannot be determined in '{ctx_fn}'.")
|
||||
}
|
||||
Self::MatchSize(..) => {
|
||||
unimplemented!("The unsafety of {self:?} cannot be determined in '{ctx_fn}'.")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for Expression {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
lazy_static! {
|
||||
static ref MACRO_RE: Regex =
|
||||
Regex::new(r"^(?P<name>[\w\d_]+)!\((?P<ex>.*?)\);?$").unwrap();
|
||||
}
|
||||
|
||||
if s == "SvUndef" {
|
||||
Ok(Expression::SvUndef)
|
||||
} else if MACRO_RE.is_match(s) {
|
||||
let c = MACRO_RE.captures(s).unwrap();
|
||||
let ex = c["ex"].to_string();
|
||||
let _: TokenStream = ex
|
||||
.parse()
|
||||
.map_err(|e| format!("could not parse macro call expression: {e:#?}"))?;
|
||||
Ok(Expression::MacroCall(c["name"].to_string(), ex))
|
||||
} else {
|
||||
let (s, id_type) = if let Some(varname) = s.strip_prefix('$') {
|
||||
(varname, IdentifierType::Variable)
|
||||
} else {
|
||||
(s, IdentifierType::Symbol)
|
||||
};
|
||||
let identifier = s.trim().parse()?;
|
||||
Ok(Expression::Identifier(identifier, id_type))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<FnCall> for Expression {
|
||||
fn from(fn_call: FnCall) -> Self {
|
||||
Expression::FnCall(fn_call)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<WildString> for Expression {
|
||||
fn from(ws: WildString) -> Self {
|
||||
Expression::Identifier(ws, IdentifierType::Symbol)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&Argument> for Expression {
|
||||
fn from(a: &Argument) -> Self {
|
||||
Expression::Identifier(a.name.to_owned(), IdentifierType::Variable)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&StaticDefinition> for Expression {
|
||||
type Error = String;
|
||||
|
||||
fn try_from(sd: &StaticDefinition) -> Result<Self, Self::Error> {
|
||||
match sd {
|
||||
StaticDefinition::Constant(imm) => Ok(imm.into()),
|
||||
StaticDefinition::Generic(t) => t.parse(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Expression {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Identifier(identifier, kind) => {
|
||||
write!(
|
||||
f,
|
||||
"{}{identifier}",
|
||||
matches!(kind, IdentifierType::Variable)
|
||||
.then_some("$")
|
||||
.unwrap_or_default()
|
||||
)
|
||||
}
|
||||
Self::MacroCall(name, expression) => {
|
||||
write!(f, "{name}!({expression})")
|
||||
}
|
||||
_ => Err(fmt::Error),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ToTokens for Expression {
|
||||
fn to_tokens(&self, tokens: &mut TokenStream) {
|
||||
match self {
|
||||
Self::Let(LetVariant::Basic(var_name, exp)) => {
|
||||
let var_ident = format_ident!("{}", var_name.to_string());
|
||||
tokens.append_all(quote! { let #var_ident = #exp })
|
||||
}
|
||||
Self::Let(LetVariant::WithType(var_name, ty, exp)) => {
|
||||
let var_ident = format_ident!("{}", var_name.to_string());
|
||||
tokens.append_all(quote! { let #var_ident: #ty = #exp })
|
||||
}
|
||||
Self::Assign(var_name, exp) => {
|
||||
let var_ident = format_ident!("{}", var_name);
|
||||
tokens.append_all(quote! { #var_ident = #exp })
|
||||
}
|
||||
Self::MacroCall(name, ex) => {
|
||||
let name = format_ident!("{name}");
|
||||
let ex: TokenStream = ex.parse().unwrap();
|
||||
tokens.append_all(quote! { #name!(#ex) })
|
||||
}
|
||||
Self::FnCall(fn_call) => fn_call.to_tokens(tokens),
|
||||
Self::MethodCall(exp, fn_name, args) => {
|
||||
let fn_ident = format_ident!("{}", fn_name);
|
||||
tokens.append_all(quote! { #exp.#fn_ident(#(#args),*) })
|
||||
}
|
||||
Self::Identifier(identifier, _) => {
|
||||
assert!(
|
||||
!identifier.has_wildcards(),
|
||||
"expression {self:#?} was not built before calling to_tokens"
|
||||
);
|
||||
identifier
|
||||
.to_string()
|
||||
.parse::<TokenStream>()
|
||||
.expect("invalid syntax")
|
||||
.to_tokens(tokens);
|
||||
}
|
||||
Self::IntConstant(n) => tokens.append(Literal::i32_unsuffixed(*n)),
|
||||
Self::FloatConstant(n) => tokens.append(Literal::f32_unsuffixed(*n)),
|
||||
Self::BoolConstant(true) => tokens.append(format_ident!("true")),
|
||||
Self::BoolConstant(false) => tokens.append(format_ident!("false")),
|
||||
Self::Array(vec) => tokens.append_all(quote! { [ #(#vec),* ] }),
|
||||
Self::LLVMLink(link) => link.to_tokens(tokens),
|
||||
Self::CastAs(ex, ty) => {
|
||||
let ty: TokenStream = ty.parse().expect("invalid syntax");
|
||||
tokens.append_all(quote! { #ex as #ty })
|
||||
}
|
||||
Self::SvUndef => tokens.append_all(quote! { simd_reinterpret(()) }),
|
||||
Self::Multiply(lhs, rhs) => tokens.append_all(quote! { #lhs * #rhs }),
|
||||
Self::Type(ty) => ty.to_tokens(tokens),
|
||||
_ => unreachable!("{self:?} cannot be converted to tokens."),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for Expression {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
match self {
|
||||
Self::IntConstant(v) => serializer.serialize_i32(*v),
|
||||
Self::FloatConstant(v) => serializer.serialize_f32(*v),
|
||||
Self::BoolConstant(v) => serializer.serialize_bool(*v),
|
||||
Self::Identifier(..) => serializer.serialize_str(&self.to_string()),
|
||||
Self::MacroCall(..) => serializer.serialize_str(&self.to_string()),
|
||||
_ => Expression::serialize(self, serializer),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for Expression {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
struct CustomExpressionVisitor;
|
||||
|
||||
impl<'de> Visitor<'de> for CustomExpressionVisitor {
|
||||
type Value = Expression;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
formatter.write_str("integer, float, boolean, string or map")
|
||||
}
|
||||
|
||||
fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
|
||||
where
|
||||
E: de::Error,
|
||||
{
|
||||
Ok(Expression::BoolConstant(v))
|
||||
}
|
||||
|
||||
fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
|
||||
where
|
||||
E: de::Error,
|
||||
{
|
||||
Ok(Expression::IntConstant(v as i32))
|
||||
}
|
||||
|
||||
fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
|
||||
where
|
||||
E: de::Error,
|
||||
{
|
||||
Ok(Expression::IntConstant(v as i32))
|
||||
}
|
||||
|
||||
fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
|
||||
where
|
||||
E: de::Error,
|
||||
{
|
||||
Ok(Expression::FloatConstant(v as f32))
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: de::Error,
|
||||
{
|
||||
FromStr::from_str(value).map_err(de::Error::custom)
|
||||
}
|
||||
|
||||
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
|
||||
where
|
||||
A: de::SeqAccess<'de>,
|
||||
{
|
||||
let arr = std::iter::from_fn(|| seq.next_element::<Self::Value>().transpose())
|
||||
.try_collect()?;
|
||||
Ok(Expression::Array(arr))
|
||||
}
|
||||
|
||||
fn visit_map<M>(self, map: M) -> Result<Expression, M::Error>
|
||||
where
|
||||
M: MapAccess<'de>,
|
||||
{
|
||||
// `MapAccessDeserializer` is a wrapper that turns a `MapAccess`
|
||||
// into a `Deserializer`, allowing it to be used as the input to T's
|
||||
// `Deserialize` implementation. T then deserializes itself using
|
||||
// the entries from the map visitor.
|
||||
Expression::deserialize(de::value::MapAccessDeserializer::new(map))
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_any(CustomExpressionVisitor)
|
||||
}
|
||||
}
|
||||
432
library/stdarch/crates/stdarch-gen2/src/input.rs
Normal file
432
library/stdarch/crates/stdarch-gen2/src/input.rs
Normal file
|
|
@ -0,0 +1,432 @@
|
|||
use itertools::Itertools;
|
||||
use serde::{de, Deserialize, Deserializer, Serialize};
|
||||
|
||||
use crate::{
|
||||
context::{self, GlobalContext},
|
||||
intrinsic::Intrinsic,
|
||||
predicate_forms::{PredicateForm, PredicationMask, PredicationMethods},
|
||||
typekinds::TypeKind,
|
||||
wildstring::WildString,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum InputType {
|
||||
/// PredicateForm variant argument
|
||||
#[serde(skip)] // Predicate forms have their own dedicated deserialization field. Skip.
|
||||
PredicateForm(PredicateForm),
|
||||
/// Operand from which to generate an N variant
|
||||
#[serde(skip)]
|
||||
NVariantOp(Option<WildString>),
|
||||
/// TypeKind variant argument
|
||||
Type(TypeKind),
|
||||
}
|
||||
|
||||
impl InputType {
|
||||
/// Optionally unwraps as a PredicateForm.
|
||||
pub fn predicate_form(&self) -> Option<&PredicateForm> {
|
||||
match self {
|
||||
InputType::PredicateForm(pf) => Some(pf),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Optionally unwraps as a mutable PredicateForm
|
||||
pub fn predicate_form_mut(&mut self) -> Option<&mut PredicateForm> {
|
||||
match self {
|
||||
InputType::PredicateForm(pf) => Some(pf),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Optionally unwraps as a TypeKind.
|
||||
pub fn typekind(&self) -> Option<&TypeKind> {
|
||||
match self {
|
||||
InputType::Type(ty) => Some(ty),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Optionally unwraps as a NVariantOp
|
||||
pub fn n_variant_op(&self) -> Option<&WildString> {
|
||||
match self {
|
||||
InputType::NVariantOp(Some(op)) => Some(op),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialOrd for InputType {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
impl Ord for InputType {
|
||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||
use std::cmp::Ordering::*;
|
||||
|
||||
match (self, other) {
|
||||
(InputType::PredicateForm(pf1), InputType::PredicateForm(pf2)) => pf1.cmp(pf2),
|
||||
(InputType::Type(ty1), InputType::Type(ty2)) => ty1.cmp(ty2),
|
||||
|
||||
(InputType::NVariantOp(None), InputType::NVariantOp(Some(..))) => Less,
|
||||
(InputType::NVariantOp(Some(..)), InputType::NVariantOp(None)) => Greater,
|
||||
(InputType::NVariantOp(_), InputType::NVariantOp(_)) => Equal,
|
||||
|
||||
(InputType::Type(..), InputType::PredicateForm(..)) => Less,
|
||||
(InputType::PredicateForm(..), InputType::Type(..)) => Greater,
|
||||
|
||||
(InputType::Type(..), InputType::NVariantOp(..)) => Less,
|
||||
(InputType::NVariantOp(..), InputType::Type(..)) => Greater,
|
||||
|
||||
(InputType::PredicateForm(..), InputType::NVariantOp(..)) => Less,
|
||||
(InputType::NVariantOp(..), InputType::PredicateForm(..)) => Greater,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mod many_or_one {
|
||||
use serde::{de::Deserializer, ser::Serializer, Deserialize, Serialize};
|
||||
|
||||
pub fn serialize<T, S>(vec: &Vec<T>, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
T: Serialize,
|
||||
S: Serializer,
|
||||
{
|
||||
if vec.len() == 1 {
|
||||
vec.first().unwrap().serialize(serializer)
|
||||
} else {
|
||||
vec.serialize(serializer)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn deserialize<'de, T, D>(deserializer: D) -> Result<Vec<T>, D::Error>
|
||||
where
|
||||
T: Deserialize<'de>,
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum ManyOrOne<T> {
|
||||
Many(Vec<T>),
|
||||
One(T),
|
||||
}
|
||||
|
||||
match ManyOrOne::deserialize(deserializer)? {
|
||||
ManyOrOne::Many(vec) => Ok(vec),
|
||||
ManyOrOne::One(val) => Ok(vec![val]),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub struct InputSet(#[serde(with = "many_or_one")] Vec<InputType>);
|
||||
|
||||
impl InputSet {
|
||||
pub fn get(&self, idx: usize) -> Option<&InputType> {
|
||||
self.0.get(idx)
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.0.is_empty()
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> impl Iterator<Item = &InputType> + '_ {
|
||||
self.0.iter()
|
||||
}
|
||||
|
||||
pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut InputType> + '_ {
|
||||
self.0.iter_mut()
|
||||
}
|
||||
|
||||
pub fn into_iter(self) -> impl Iterator<Item = InputType> + Clone {
|
||||
self.0.into_iter()
|
||||
}
|
||||
|
||||
pub fn types_len(&self) -> usize {
|
||||
self.iter().filter_map(|arg| arg.typekind()).count()
|
||||
}
|
||||
|
||||
pub fn typekind(&self, idx: Option<usize>) -> Option<TypeKind> {
|
||||
let types_len = self.types_len();
|
||||
self.get(idx.unwrap_or(0)).and_then(move |arg: &InputType| {
|
||||
if (idx.is_none() && types_len != 1) || (idx.is_some() && types_len == 1) {
|
||||
None
|
||||
} else {
|
||||
arg.typekind().cloned()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct InputSetEntry(#[serde(with = "many_or_one")] Vec<InputSet>);
|
||||
|
||||
impl InputSetEntry {
|
||||
pub fn new(input: Vec<InputSet>) -> Self {
|
||||
Self(input)
|
||||
}
|
||||
|
||||
pub fn get(&self, idx: usize) -> Option<&InputSet> {
|
||||
self.0.get(idx)
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_types<'de, D>(deserializer: D) -> Result<Vec<InputSetEntry>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let v: Vec<InputSetEntry> = Vec::deserialize(deserializer)?;
|
||||
|
||||
let mut it = v.iter();
|
||||
if let Some(first) = it.next() {
|
||||
it.try_fold(first, |last, cur| {
|
||||
if last.0.len() == cur.0.len() {
|
||||
Ok(cur)
|
||||
} else {
|
||||
Err("the length of the InputSets and the product lists must match".to_string())
|
||||
}
|
||||
})
|
||||
.map_err(de::Error::custom)?;
|
||||
}
|
||||
|
||||
Ok(v)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct IntrinsicInput {
|
||||
#[serde(default)]
|
||||
#[serde(deserialize_with = "validate_types")]
|
||||
pub types: Vec<InputSetEntry>,
|
||||
|
||||
#[serde(flatten)]
|
||||
pub predication_methods: PredicationMethods,
|
||||
|
||||
/// Generates a _n variant where the specified operand is a primitive type
|
||||
/// that requires conversion to an SVE one. The `{_n}` wildcard is required
|
||||
/// in the intrinsic's name, otherwise an error will be thrown.
|
||||
#[serde(default)]
|
||||
pub n_variant_op: WildString,
|
||||
}
|
||||
|
||||
impl IntrinsicInput {
|
||||
/// Extracts all the possible variants as an iterator.
|
||||
pub fn variants(
|
||||
&self,
|
||||
intrinsic: &Intrinsic,
|
||||
) -> context::Result<impl Iterator<Item = InputSet> + '_> {
|
||||
let mut top_product = vec![];
|
||||
|
||||
if !self.types.is_empty() {
|
||||
top_product.push(
|
||||
self.types
|
||||
.iter()
|
||||
.flat_map(|ty_in| {
|
||||
ty_in
|
||||
.0
|
||||
.iter()
|
||||
.map(|v| v.clone().into_iter())
|
||||
.multi_cartesian_product()
|
||||
})
|
||||
.collect_vec(),
|
||||
)
|
||||
}
|
||||
|
||||
if let Ok(mask) = PredicationMask::try_from(&intrinsic.signature.name) {
|
||||
top_product.push(
|
||||
PredicateForm::compile_list(&mask, &self.predication_methods)?
|
||||
.into_iter()
|
||||
.map(|pf| vec![InputType::PredicateForm(pf)])
|
||||
.collect_vec(),
|
||||
)
|
||||
}
|
||||
|
||||
if !self.n_variant_op.is_empty() {
|
||||
top_product.push(vec![
|
||||
vec![InputType::NVariantOp(None)],
|
||||
vec![InputType::NVariantOp(Some(self.n_variant_op.to_owned()))],
|
||||
])
|
||||
}
|
||||
|
||||
let it = top_product
|
||||
.into_iter()
|
||||
.map(|v| v.into_iter())
|
||||
.multi_cartesian_product()
|
||||
.map(|set| InputSet(set.into_iter().flatten().collect_vec()));
|
||||
Ok(it)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GeneratorInput {
|
||||
#[serde(flatten)]
|
||||
pub ctx: GlobalContext,
|
||||
pub intrinsics: Vec<Intrinsic>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
input::*,
|
||||
predicate_forms::{DontCareMethod, ZeroingMethod},
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_empty() {
|
||||
let str = r#"types: []"#;
|
||||
let input: IntrinsicInput = serde_yaml::from_str(str).expect("failed to parse");
|
||||
let mut variants = input.variants(&Intrinsic::default()).unwrap().into_iter();
|
||||
assert_eq!(variants.next(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_product() {
|
||||
let str = r#"types:
|
||||
- [f64, f32]
|
||||
- [i64, [f64, f32]]
|
||||
"#;
|
||||
let input: IntrinsicInput = serde_yaml::from_str(str).expect("failed to parse");
|
||||
let mut intrinsic = Intrinsic::default();
|
||||
intrinsic.signature.name = "test_intrinsic{_mx}".parse().unwrap();
|
||||
let mut variants = input.variants(&intrinsic).unwrap().into_iter();
|
||||
assert_eq!(
|
||||
variants.next(),
|
||||
Some(InputSet(vec![
|
||||
InputType::Type("f64".parse().unwrap()),
|
||||
InputType::Type("f32".parse().unwrap()),
|
||||
InputType::PredicateForm(PredicateForm::Merging),
|
||||
]))
|
||||
);
|
||||
assert_eq!(
|
||||
variants.next(),
|
||||
Some(InputSet(vec![
|
||||
InputType::Type("f64".parse().unwrap()),
|
||||
InputType::Type("f32".parse().unwrap()),
|
||||
InputType::PredicateForm(PredicateForm::DontCare(DontCareMethod::AsMerging)),
|
||||
]))
|
||||
);
|
||||
assert_eq!(
|
||||
variants.next(),
|
||||
Some(InputSet(vec![
|
||||
InputType::Type("i64".parse().unwrap()),
|
||||
InputType::Type("f64".parse().unwrap()),
|
||||
InputType::PredicateForm(PredicateForm::Merging),
|
||||
]))
|
||||
);
|
||||
assert_eq!(
|
||||
variants.next(),
|
||||
Some(InputSet(vec![
|
||||
InputType::Type("i64".parse().unwrap()),
|
||||
InputType::Type("f64".parse().unwrap()),
|
||||
InputType::PredicateForm(PredicateForm::DontCare(DontCareMethod::AsMerging)),
|
||||
]))
|
||||
);
|
||||
assert_eq!(
|
||||
variants.next(),
|
||||
Some(InputSet(vec![
|
||||
InputType::Type("i64".parse().unwrap()),
|
||||
InputType::Type("f32".parse().unwrap()),
|
||||
InputType::PredicateForm(PredicateForm::Merging),
|
||||
]))
|
||||
);
|
||||
assert_eq!(
|
||||
variants.next(),
|
||||
Some(InputSet(vec![
|
||||
InputType::Type("i64".parse().unwrap()),
|
||||
InputType::Type("f32".parse().unwrap()),
|
||||
InputType::PredicateForm(PredicateForm::DontCare(DontCareMethod::AsMerging)),
|
||||
])),
|
||||
);
|
||||
assert_eq!(variants.next(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_n_variant() {
|
||||
let str = r#"types:
|
||||
- [f64, f32]
|
||||
n_variant_op: op2
|
||||
"#;
|
||||
let input: IntrinsicInput = serde_yaml::from_str(str).expect("failed to parse");
|
||||
let mut variants = input.variants(&Intrinsic::default()).unwrap().into_iter();
|
||||
assert_eq!(
|
||||
variants.next(),
|
||||
Some(InputSet(vec![
|
||||
InputType::Type("f64".parse().unwrap()),
|
||||
InputType::Type("f32".parse().unwrap()),
|
||||
InputType::NVariantOp(None),
|
||||
]))
|
||||
);
|
||||
assert_eq!(
|
||||
variants.next(),
|
||||
Some(InputSet(vec![
|
||||
InputType::Type("f64".parse().unwrap()),
|
||||
InputType::Type("f32".parse().unwrap()),
|
||||
InputType::NVariantOp(Some("op2".parse().unwrap())),
|
||||
]))
|
||||
);
|
||||
assert_eq!(variants.next(), None)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_length() {
|
||||
let str = r#"types: [i32, [[u64], [u32]]]"#;
|
||||
serde_yaml::from_str::<IntrinsicInput>(str).expect_err("failure expected");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_predication() {
|
||||
let str = "types: []";
|
||||
let input: IntrinsicInput = serde_yaml::from_str(str).expect("failed to parse");
|
||||
let mut intrinsic = Intrinsic::default();
|
||||
intrinsic.signature.name = "test_intrinsic{_mxz}".parse().unwrap();
|
||||
input
|
||||
.variants(&intrinsic)
|
||||
.map(|v| v.collect_vec())
|
||||
.expect_err("failure expected");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_predication_mask() {
|
||||
"test_intrinsic{_mxy}"
|
||||
.parse::<WildString>()
|
||||
.expect_err("failure expected");
|
||||
"test_intrinsic{_}"
|
||||
.parse::<WildString>()
|
||||
.expect_err("failure expected");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_zeroing_predication() {
|
||||
let str = r#"types: [i64]
|
||||
zeroing_method: { drop: inactive }"#;
|
||||
let input: IntrinsicInput = serde_yaml::from_str(str).expect("failed to parse");
|
||||
let mut intrinsic = Intrinsic::default();
|
||||
intrinsic.signature.name = "test_intrinsic{_mxz}".parse().unwrap();
|
||||
let mut variants = input.variants(&intrinsic).unwrap();
|
||||
assert_eq!(
|
||||
variants.next(),
|
||||
Some(InputSet(vec![
|
||||
InputType::Type("i64".parse().unwrap()),
|
||||
InputType::PredicateForm(PredicateForm::Merging),
|
||||
]))
|
||||
);
|
||||
assert_eq!(
|
||||
variants.next(),
|
||||
Some(InputSet(vec![
|
||||
InputType::Type("i64".parse().unwrap()),
|
||||
InputType::PredicateForm(PredicateForm::DontCare(DontCareMethod::AsZeroing)),
|
||||
]))
|
||||
);
|
||||
assert_eq!(
|
||||
variants.next(),
|
||||
Some(InputSet(vec![
|
||||
InputType::Type("i64".parse().unwrap()),
|
||||
InputType::PredicateForm(PredicateForm::Zeroing(ZeroingMethod::Drop {
|
||||
drop: "inactive".parse().unwrap()
|
||||
})),
|
||||
]))
|
||||
);
|
||||
assert_eq!(variants.next(), None)
|
||||
}
|
||||
}
|
||||
1498
library/stdarch/crates/stdarch-gen2/src/intrinsic.rs
Normal file
1498
library/stdarch/crates/stdarch-gen2/src/intrinsic.rs
Normal file
File diff suppressed because it is too large
Load diff
818
library/stdarch/crates/stdarch-gen2/src/load_store_tests.rs
Normal file
818
library/stdarch/crates/stdarch-gen2/src/load_store_tests.rs
Normal file
|
|
@ -0,0 +1,818 @@
|
|||
use std::fs::File;
|
||||
use std::io::Write;
|
||||
use std::path::PathBuf;
|
||||
use std::str::FromStr;
|
||||
|
||||
use crate::format_code;
|
||||
use crate::input::InputType;
|
||||
use crate::intrinsic::Intrinsic;
|
||||
use crate::typekinds::BaseType;
|
||||
use crate::typekinds::{ToRepr, TypeKind};
|
||||
|
||||
use itertools::Itertools;
|
||||
use lazy_static::lazy_static;
|
||||
use proc_macro2::TokenStream;
|
||||
use quote::{format_ident, quote};
|
||||
|
||||
// Number of vectors in our buffers - the maximum tuple size, 4, plus 1 as we set the vnum
|
||||
// argument to 1.
|
||||
const NUM_VECS: usize = 5;
|
||||
// The maximum vector length (in bits)
|
||||
const VL_MAX_BITS: usize = 2048;
|
||||
// The maximum vector length (in bytes)
|
||||
const VL_MAX_BYTES: usize = VL_MAX_BITS / 8;
|
||||
// The maximum number of elements in each vector type
|
||||
const LEN_F32: usize = VL_MAX_BYTES / core::mem::size_of::<f32>();
|
||||
const LEN_F64: usize = VL_MAX_BYTES / core::mem::size_of::<f64>();
|
||||
const LEN_I8: usize = VL_MAX_BYTES / core::mem::size_of::<i8>();
|
||||
const LEN_I16: usize = VL_MAX_BYTES / core::mem::size_of::<i16>();
|
||||
const LEN_I32: usize = VL_MAX_BYTES / core::mem::size_of::<i32>();
|
||||
const LEN_I64: usize = VL_MAX_BYTES / core::mem::size_of::<i64>();
|
||||
const LEN_U8: usize = VL_MAX_BYTES / core::mem::size_of::<u8>();
|
||||
const LEN_U16: usize = VL_MAX_BYTES / core::mem::size_of::<u16>();
|
||||
const LEN_U32: usize = VL_MAX_BYTES / core::mem::size_of::<u32>();
|
||||
const LEN_U64: usize = VL_MAX_BYTES / core::mem::size_of::<u64>();
|
||||
|
||||
/// `load_intrinsics` and `store_intrinsics` is a vector of intrinsics
|
||||
/// variants, while `out_path` is a file to write to.
|
||||
pub fn generate_load_store_tests(
|
||||
load_intrinsics: Vec<Intrinsic>,
|
||||
store_intrinsics: Vec<Intrinsic>,
|
||||
out_path: Option<&PathBuf>,
|
||||
) -> Result<(), String> {
|
||||
let output = match out_path {
|
||||
Some(out) => {
|
||||
Box::new(File::create(out).map_err(|e| format!("couldn't create tests file: {e}"))?)
|
||||
as Box<dyn Write>
|
||||
}
|
||||
None => Box::new(std::io::stdout()) as Box<dyn Write>,
|
||||
};
|
||||
let mut used_stores = vec![false; store_intrinsics.len()];
|
||||
let tests: Vec<_> = load_intrinsics
|
||||
.iter()
|
||||
.map(|load| {
|
||||
let store_candidate = load
|
||||
.signature
|
||||
.fn_name()
|
||||
.to_string()
|
||||
.replace("svld1s", "svst1")
|
||||
.replace("svld1u", "svst1")
|
||||
.replace("svldnt1s", "svstnt1")
|
||||
.replace("svldnt1u", "svstnt1")
|
||||
.replace("svld", "svst")
|
||||
.replace("gather", "scatter");
|
||||
|
||||
let store_index = store_intrinsics
|
||||
.iter()
|
||||
.position(|i| i.signature.fn_name().to_string() == store_candidate);
|
||||
if let Some(i) = store_index {
|
||||
used_stores[i] = true;
|
||||
}
|
||||
|
||||
generate_single_test(
|
||||
load.clone(),
|
||||
store_index.map(|i| store_intrinsics[i].clone()),
|
||||
)
|
||||
})
|
||||
.try_collect()?;
|
||||
|
||||
assert!(used_stores.into_iter().all(|b| b), "Not all store tests have been paired with a load. Consider generating specifc store-only tests");
|
||||
|
||||
let preamble =
|
||||
TokenStream::from_str(&PREAMBLE).map_err(|e| format!("Preamble is invalid: {e}"))?;
|
||||
// Only output manual tests for the SVE set
|
||||
let manual_tests = match &load_intrinsics[0].target_features[..] {
|
||||
[s] if s == "sve" => TokenStream::from_str(&MANUAL_TESTS)
|
||||
.map_err(|e| format!("Manual tests are invalid: {e}"))?,
|
||||
_ => quote!(),
|
||||
};
|
||||
format_code(
|
||||
output,
|
||||
format!(
|
||||
"// This code is automatically generated. DO NOT MODIFY.
|
||||
//
|
||||
// Instead, modify `crates/stdarch-gen2/spec/sve` and run the following command to re-generate this
|
||||
// file:
|
||||
//
|
||||
// ```
|
||||
// cargo run --bin=stdarch-gen2 -- crates/stdarch-gen2/spec
|
||||
// ```
|
||||
{}",
|
||||
quote! { #preamble #(#tests)* #manual_tests }
|
||||
),
|
||||
)
|
||||
.map_err(|e| format!("couldn't write tests: {e}"))
|
||||
}
|
||||
|
||||
/// A test looks like this:
|
||||
/// ```
|
||||
/// let data = [scalable vector];
|
||||
///
|
||||
/// let mut storage = [0; N];
|
||||
///
|
||||
/// store_intrinsic([true_predicate], storage.as_mut_ptr(), data);
|
||||
/// [test contents of storage]
|
||||
///
|
||||
/// let loaded == load_intrinsic([true_predicate], storage.as_ptr())
|
||||
/// assert!(loaded == data);
|
||||
/// ```
|
||||
/// We intialise our data such that the value stored matches the index it's stored to.
|
||||
/// By doing this we can validate scatters by checking that each value in the storage
|
||||
/// array is either 0 or the same as its index.
|
||||
fn generate_single_test(
|
||||
load: Intrinsic,
|
||||
store: Option<Intrinsic>,
|
||||
) -> Result<proc_macro2::TokenStream, String> {
|
||||
let chars = LdIntrCharacteristics::new(&load)?;
|
||||
let fn_name = load.signature.fn_name().to_string();
|
||||
|
||||
if let Some(ty) = &chars.gather_bases_type {
|
||||
if ty.base_type().unwrap().get_size() == Ok(32)
|
||||
&& chars.gather_index_type.is_none()
|
||||
&& chars.gather_offset_type.is_none()
|
||||
{
|
||||
// We lack a way to ensure data is in the bottom 32 bits of the address space
|
||||
println!("Skipping test for {fn_name}");
|
||||
return Ok(quote!());
|
||||
}
|
||||
}
|
||||
|
||||
if fn_name.starts_with("svldff1") && fn_name.contains("gather") {
|
||||
// TODO: We can remove this check when first-faulting gathers are fixed in CI's QEMU
|
||||
// https://gitlab.com/qemu-project/qemu/-/issues/1612
|
||||
println!("Skipping test for {fn_name}");
|
||||
return Ok(quote!());
|
||||
}
|
||||
|
||||
let fn_ident = format_ident!("{fn_name}");
|
||||
let test_name = format_ident!(
|
||||
"test_{fn_name}{}",
|
||||
if let Some(ref store) = store {
|
||||
format!("_with_{}", store.signature.fn_name())
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
);
|
||||
|
||||
let load_type = &chars.load_type;
|
||||
let acle_type = load_type.acle_notation_repr();
|
||||
|
||||
// If there's no return type, fallback to the load type for things that depend on it
|
||||
let ret_type = &load
|
||||
.signature
|
||||
.return_type
|
||||
.as_ref()
|
||||
.and_then(TypeKind::base_type)
|
||||
.unwrap_or(load_type);
|
||||
|
||||
let pred_fn = format_ident!("svptrue_b{}", load_type.size());
|
||||
|
||||
let load_type_caps = load_type.rust_repr().to_uppercase();
|
||||
let data_array = format_ident!("{load_type_caps}_DATA");
|
||||
|
||||
let size_fn = format_ident!("svcnt{}", ret_type.size_literal());
|
||||
|
||||
let rust_ret_type = ret_type.rust_repr();
|
||||
let assert_fn = format_ident!("assert_vector_matches_{rust_ret_type}");
|
||||
|
||||
// Use vnum=1, so adjust all values by one vector length
|
||||
let (length_call, vnum_arg) = if chars.vnum {
|
||||
if chars.is_prf {
|
||||
(quote!(), quote!(, 1))
|
||||
} else {
|
||||
(quote!(let len = #size_fn() as usize;), quote!(, 1))
|
||||
}
|
||||
} else {
|
||||
(quote!(), quote!())
|
||||
};
|
||||
|
||||
let (bases_load, bases_arg) = if let Some(ty) = &chars.gather_bases_type {
|
||||
// Bases is a vector of (sometimes 32-bit) pointers
|
||||
// When we combine bases with an offset/index argument, we load from the data arrays
|
||||
// starting at 1
|
||||
let base_ty = ty.base_type().unwrap();
|
||||
let rust_type = format_ident!("{}", base_ty.rust_repr());
|
||||
let index_fn = format_ident!("svindex_{}", base_ty.acle_notation_repr());
|
||||
let size_in_bytes = chars.load_type.get_size().unwrap() / 8;
|
||||
|
||||
if base_ty.get_size().unwrap() == 32 {
|
||||
// Treat bases as a vector of offsets here - we don't test this without an offset or
|
||||
// index argument
|
||||
(
|
||||
Some(quote!(
|
||||
let bases = #index_fn(0, #size_in_bytes.try_into().unwrap());
|
||||
)),
|
||||
quote!(, bases),
|
||||
)
|
||||
} else {
|
||||
// Treat bases as a vector of pointers
|
||||
let base_fn = format_ident!("svdup_n_{}", base_ty.acle_notation_repr());
|
||||
let data_array = if store.is_some() {
|
||||
format_ident!("storage")
|
||||
} else {
|
||||
format_ident!("{}_DATA", chars.load_type.rust_repr().to_uppercase())
|
||||
};
|
||||
|
||||
let add_fn = format_ident!("svadd_{}_x", base_ty.acle_notation_repr());
|
||||
(
|
||||
Some(quote! {
|
||||
let bases = #base_fn(#data_array.as_ptr() as #rust_type);
|
||||
let offsets = #index_fn(0, #size_in_bytes.try_into().unwrap());
|
||||
let bases = #add_fn(#pred_fn(), bases, offsets);
|
||||
}),
|
||||
quote!(, bases),
|
||||
)
|
||||
}
|
||||
} else {
|
||||
(None, quote!())
|
||||
};
|
||||
|
||||
let index_arg = if let Some(ty) = &chars.gather_index_type {
|
||||
let rust_type = format_ident!("{}", ty.rust_repr());
|
||||
if chars
|
||||
.gather_bases_type
|
||||
.as_ref()
|
||||
.and_then(TypeKind::base_type)
|
||||
.map_or(Err(String::new()), BaseType::get_size)
|
||||
.unwrap()
|
||||
== 32
|
||||
{
|
||||
// Let index be the base of the data array
|
||||
let data_array = if store.is_some() {
|
||||
format_ident!("storage")
|
||||
} else {
|
||||
format_ident!("{}_DATA", chars.load_type.rust_repr().to_uppercase())
|
||||
};
|
||||
let size_in_bytes = chars.load_type.get_size().unwrap() / 8;
|
||||
quote!(, #data_array.as_ptr() as #rust_type / (#size_in_bytes as #rust_type) + 1)
|
||||
} else {
|
||||
quote!(, 1.try_into().unwrap())
|
||||
}
|
||||
} else {
|
||||
quote!()
|
||||
};
|
||||
|
||||
let offset_arg = if let Some(ty) = &chars.gather_offset_type {
|
||||
let size_in_bytes = chars.load_type.get_size().unwrap() / 8;
|
||||
if chars
|
||||
.gather_bases_type
|
||||
.as_ref()
|
||||
.and_then(TypeKind::base_type)
|
||||
.map_or(Err(String::new()), BaseType::get_size)
|
||||
.unwrap()
|
||||
== 32
|
||||
{
|
||||
// Let offset be the base of the data array
|
||||
let rust_type = format_ident!("{}", ty.rust_repr());
|
||||
let data_array = if store.is_some() {
|
||||
format_ident!("storage")
|
||||
} else {
|
||||
format_ident!("{}_DATA", chars.load_type.rust_repr().to_uppercase())
|
||||
};
|
||||
quote!(, #data_array.as_ptr() as #rust_type + #size_in_bytes as #rust_type)
|
||||
} else {
|
||||
quote!(, #size_in_bytes.try_into().unwrap())
|
||||
}
|
||||
} else {
|
||||
quote!()
|
||||
};
|
||||
|
||||
let (offsets_load, offsets_arg) = if let Some(ty) = &chars.gather_offsets_type {
|
||||
// Offsets is a scalable vector of per-element offsets in bytes. We re-use the contiguous
|
||||
// data for this, then multiply to get indices
|
||||
let offsets_fn = format_ident!("svindex_{}", ty.base_type().unwrap().acle_notation_repr());
|
||||
let size_in_bytes = chars.load_type.get_size().unwrap() / 8;
|
||||
(
|
||||
Some(quote! {
|
||||
let offsets = #offsets_fn(0, #size_in_bytes.try_into().unwrap());
|
||||
}),
|
||||
quote!(, offsets),
|
||||
)
|
||||
} else {
|
||||
(None, quote!())
|
||||
};
|
||||
|
||||
let (indices_load, indices_arg) = if let Some(ty) = &chars.gather_indices_type {
|
||||
// There's no need to multiply indices by the load type width
|
||||
let base_ty = ty.base_type().unwrap();
|
||||
let indices_fn = format_ident!("svindex_{}", base_ty.acle_notation_repr());
|
||||
(
|
||||
Some(quote! {
|
||||
let indices = #indices_fn(0, 1);
|
||||
}),
|
||||
quote! {, indices},
|
||||
)
|
||||
} else {
|
||||
(None, quote!())
|
||||
};
|
||||
|
||||
let ptr = if chars.gather_bases_type.is_some() {
|
||||
quote!()
|
||||
} else if chars.is_prf {
|
||||
quote!(, I64_DATA.as_ptr())
|
||||
} else {
|
||||
quote!(, #data_array.as_ptr())
|
||||
};
|
||||
|
||||
let tuple_len = &chars.tuple_len;
|
||||
let expecteds = if chars.is_prf {
|
||||
// No return value for prefetches
|
||||
vec![]
|
||||
} else {
|
||||
(0..*tuple_len)
|
||||
.map(|i| get_expected_range(i, &chars))
|
||||
.collect()
|
||||
};
|
||||
let asserts: Vec<_> =
|
||||
if *tuple_len > 1 {
|
||||
let svget = format_ident!("svget{tuple_len}_{acle_type}");
|
||||
expecteds.iter().enumerate().map(|(i, expected)| {
|
||||
quote! (#assert_fn(#svget::<{ #i as i32 }>(loaded), #expected);)
|
||||
}).collect()
|
||||
} else {
|
||||
expecteds
|
||||
.iter()
|
||||
.map(|expected| quote! (#assert_fn(loaded, #expected);))
|
||||
.collect()
|
||||
};
|
||||
|
||||
let function = if chars.is_prf {
|
||||
if fn_name.contains("gather") && fn_name.contains("base") && !fn_name.starts_with("svprf_")
|
||||
{
|
||||
// svprf(b|h|w|d)_gather base intrinsics do not have a generic type parameter
|
||||
quote!(#fn_ident::<{ svprfop::SV_PLDL1KEEP }>)
|
||||
} else {
|
||||
quote!(#fn_ident::<{ svprfop::SV_PLDL1KEEP }, i64>)
|
||||
}
|
||||
} else {
|
||||
quote!(#fn_ident)
|
||||
};
|
||||
|
||||
let octaword_guard = if chars.replicate_width == Some(256) {
|
||||
let msg = format!("Skipping {test_name} due to SVE vector length");
|
||||
quote! {
|
||||
if svcntb() < 32 {
|
||||
println!(#msg);
|
||||
return;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
quote!()
|
||||
};
|
||||
|
||||
let feats = load.target_features.join(",");
|
||||
|
||||
if let Some(store) = store {
|
||||
let data_init = if *tuple_len == 1 {
|
||||
quote!(#(#expecteds)*)
|
||||
} else {
|
||||
let create = format_ident!("svcreate{tuple_len}_{acle_type}");
|
||||
quote!(#create(#(#expecteds),*))
|
||||
};
|
||||
let input = store.input.types.get(0).unwrap().get(0).unwrap();
|
||||
let store_type = input
|
||||
.get(store.test.get_typeset_index().unwrap())
|
||||
.and_then(InputType::typekind)
|
||||
.and_then(TypeKind::base_type)
|
||||
.unwrap();
|
||||
|
||||
let store_type = format_ident!("{}", store_type.rust_repr());
|
||||
let storage_len = NUM_VECS * VL_MAX_BITS / chars.load_type.get_size()? as usize;
|
||||
let store_fn = format_ident!("{}", store.signature.fn_name().to_string());
|
||||
let load_type = format_ident!("{}", chars.load_type.rust_repr());
|
||||
let (store_ptr, store_mut_ptr) = if chars.gather_bases_type.is_none() {
|
||||
(
|
||||
quote!(, storage.as_ptr() as *const #load_type),
|
||||
quote!(, storage.as_mut_ptr()),
|
||||
)
|
||||
} else {
|
||||
(quote!(), quote!())
|
||||
};
|
||||
let args = quote!(#pred_fn() #store_ptr #vnum_arg #bases_arg #offset_arg #index_arg #offsets_arg #indices_arg);
|
||||
let call = if chars.uses_ffr {
|
||||
// Doing a normal load first maximises the number of elements our ff/nf test loads
|
||||
let non_ffr_fn_name = format_ident!(
|
||||
"{}",
|
||||
fn_name
|
||||
.replace("svldff1", "svld1")
|
||||
.replace("svldnf1", "svld1")
|
||||
);
|
||||
quote! {
|
||||
svsetffr();
|
||||
let _ = #non_ffr_fn_name(#args);
|
||||
let loaded = #function(#args);
|
||||
}
|
||||
} else {
|
||||
// Note that the FFR must be set for all tests as the assert functions mask against it
|
||||
quote! {
|
||||
svsetffr();
|
||||
let loaded = #function(#args);
|
||||
}
|
||||
};
|
||||
|
||||
Ok(quote! {
|
||||
#[simd_test(enable = #feats)]
|
||||
unsafe fn #test_name() {
|
||||
#octaword_guard
|
||||
#length_call
|
||||
let mut storage = [0 as #store_type; #storage_len];
|
||||
let data = #data_init;
|
||||
#bases_load
|
||||
#offsets_load
|
||||
#indices_load
|
||||
|
||||
#store_fn(#pred_fn() #store_mut_ptr #vnum_arg #bases_arg #offset_arg #index_arg #offsets_arg #indices_arg, data);
|
||||
for (i, &val) in storage.iter().enumerate() {
|
||||
assert!(val == 0 as #store_type || val == i as #store_type);
|
||||
}
|
||||
|
||||
#call
|
||||
#(#asserts)*
|
||||
|
||||
}
|
||||
})
|
||||
} else {
|
||||
let args = quote!(#pred_fn() #ptr #vnum_arg #bases_arg #offset_arg #index_arg #offsets_arg #indices_arg);
|
||||
let call = if chars.uses_ffr {
|
||||
// Doing a normal load first maximises the number of elements our ff/nf test loads
|
||||
let non_ffr_fn_name = format_ident!(
|
||||
"{}",
|
||||
fn_name
|
||||
.replace("svldff1", "svld1")
|
||||
.replace("svldnf1", "svld1")
|
||||
);
|
||||
quote! {
|
||||
svsetffr();
|
||||
let _ = #non_ffr_fn_name(#args);
|
||||
let loaded = #function(#args);
|
||||
}
|
||||
} else {
|
||||
// Note that the FFR must be set for all tests as the assert functions mask against it
|
||||
quote! {
|
||||
svsetffr();
|
||||
let loaded = #function(#args);
|
||||
}
|
||||
};
|
||||
Ok(quote! {
|
||||
#[simd_test(enable = #feats)]
|
||||
unsafe fn #test_name() {
|
||||
#octaword_guard
|
||||
#bases_load
|
||||
#offsets_load
|
||||
#indices_load
|
||||
#call
|
||||
#length_call
|
||||
|
||||
#(#asserts)*
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Assumes chars.ret_type is not None
|
||||
fn get_expected_range(tuple_idx: usize, chars: &LdIntrCharacteristics) -> proc_macro2::TokenStream {
|
||||
// vnum=1
|
||||
let vnum_adjust = if chars.vnum { quote!(len+) } else { quote!() };
|
||||
|
||||
let bases_adjust =
|
||||
(chars.gather_index_type.is_some() || chars.gather_offset_type.is_some()) as usize;
|
||||
|
||||
let tuple_len = chars.tuple_len;
|
||||
let size = chars
|
||||
.ret_type
|
||||
.as_ref()
|
||||
.and_then(TypeKind::base_type)
|
||||
.unwrap_or(&chars.load_type)
|
||||
.get_size()
|
||||
.unwrap() as usize;
|
||||
|
||||
if chars.replicate_width == Some(128) {
|
||||
// svld1rq
|
||||
let ty_rust = format_ident!(
|
||||
"{}",
|
||||
chars
|
||||
.ret_type
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.base_type()
|
||||
.unwrap()
|
||||
.rust_repr()
|
||||
);
|
||||
let args: Vec<_> = (0..(128 / size)).map(|i| quote!(#i as #ty_rust)).collect();
|
||||
let dup = format_ident!(
|
||||
"svdupq_n_{}",
|
||||
chars.ret_type.as_ref().unwrap().acle_notation_repr()
|
||||
);
|
||||
quote!(#dup(#(#args,)*))
|
||||
} else if chars.replicate_width == Some(256) {
|
||||
// svld1ro - we use two interleaved svdups to create a repeating 256-bit pattern
|
||||
let ty_rust = format_ident!(
|
||||
"{}",
|
||||
chars
|
||||
.ret_type
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.base_type()
|
||||
.unwrap()
|
||||
.rust_repr()
|
||||
);
|
||||
let ret_acle = chars.ret_type.as_ref().unwrap().acle_notation_repr();
|
||||
let args: Vec<_> = (0..(128 / size)).map(|i| quote!(#i as #ty_rust)).collect();
|
||||
let args2: Vec<_> = ((128 / size)..(256 / size))
|
||||
.map(|i| quote!(#i as #ty_rust))
|
||||
.collect();
|
||||
let dup = format_ident!("svdupq_n_{ret_acle}");
|
||||
let interleave = format_ident!("svtrn1q_{ret_acle}");
|
||||
quote!(#interleave(#dup(#(#args,)*), #dup(#(#args2,)*)))
|
||||
} else {
|
||||
let start = bases_adjust + tuple_idx;
|
||||
if chars
|
||||
.ret_type
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.base_type()
|
||||
.unwrap()
|
||||
.is_float()
|
||||
{
|
||||
// Use svcvt to create a linear sequence of floats
|
||||
let cvt_fn = format_ident!("svcvt_f{size}_s{size}_x");
|
||||
let pred_fn = format_ident!("svptrue_b{size}");
|
||||
let svindex_fn = format_ident!("svindex_s{size}");
|
||||
quote! { #cvt_fn(#pred_fn(), #svindex_fn((#vnum_adjust #start).try_into().unwrap(), #tuple_len.try_into().unwrap()))}
|
||||
} else {
|
||||
let ret_acle = chars.ret_type.as_ref().unwrap().acle_notation_repr();
|
||||
let svindex = format_ident!("svindex_{ret_acle}");
|
||||
quote!(#svindex((#vnum_adjust #start).try_into().unwrap(), #tuple_len.try_into().unwrap()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct LdIntrCharacteristics {
|
||||
// The data type to load from (not necessarily the data type returned)
|
||||
load_type: BaseType,
|
||||
// The data type to return (None for unit)
|
||||
ret_type: Option<TypeKind>,
|
||||
// The size of tuple to load/store
|
||||
tuple_len: usize,
|
||||
// Whether a vnum argument is present
|
||||
vnum: bool,
|
||||
// Is the intrinsic first/non-faulting?
|
||||
uses_ffr: bool,
|
||||
// Is it a prefetch?
|
||||
is_prf: bool,
|
||||
// The size of data loaded with svld1ro/q intrinsics
|
||||
replicate_width: Option<usize>,
|
||||
// Scalable vector of pointers to load from
|
||||
gather_bases_type: Option<TypeKind>,
|
||||
// Scalar offset, paired with bases
|
||||
gather_offset_type: Option<TypeKind>,
|
||||
// Scalar index, paired with bases
|
||||
gather_index_type: Option<TypeKind>,
|
||||
// Scalable vector of offsets
|
||||
gather_offsets_type: Option<TypeKind>,
|
||||
// Scalable vector of indices
|
||||
gather_indices_type: Option<TypeKind>,
|
||||
}
|
||||
|
||||
impl LdIntrCharacteristics {
|
||||
fn new(intr: &Intrinsic) -> Result<LdIntrCharacteristics, String> {
|
||||
let input = intr.input.types.get(0).unwrap().get(0).unwrap();
|
||||
let load_type = input
|
||||
.get(intr.test.get_typeset_index().unwrap())
|
||||
.and_then(InputType::typekind)
|
||||
.and_then(TypeKind::base_type)
|
||||
.unwrap();
|
||||
|
||||
let ret_type = intr.signature.return_type.clone();
|
||||
|
||||
let name = intr.signature.fn_name().to_string();
|
||||
let tuple_len = name
|
||||
.chars()
|
||||
.find(|c| c.is_numeric())
|
||||
.and_then(|c| c.to_digit(10))
|
||||
.unwrap_or(1) as usize;
|
||||
|
||||
let uses_ffr = name.starts_with("svldff") || name.starts_with("svldnf");
|
||||
|
||||
let is_prf = name.starts_with("svprf");
|
||||
|
||||
let replicate_width = if name.starts_with("svld1ro") {
|
||||
Some(256)
|
||||
} else if name.starts_with("svld1rq") {
|
||||
Some(128)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let get_ty_of_arg = |name: &str| {
|
||||
intr.signature
|
||||
.arguments
|
||||
.iter()
|
||||
.find(|a| a.name.to_string() == name)
|
||||
.map(|a| a.kind.clone())
|
||||
};
|
||||
|
||||
let gather_bases_type = get_ty_of_arg("bases");
|
||||
let gather_offset_type = get_ty_of_arg("offset");
|
||||
let gather_index_type = get_ty_of_arg("index");
|
||||
let gather_offsets_type = get_ty_of_arg("offsets");
|
||||
let gather_indices_type = get_ty_of_arg("indices");
|
||||
|
||||
Ok(LdIntrCharacteristics {
|
||||
load_type: *load_type,
|
||||
ret_type,
|
||||
tuple_len,
|
||||
vnum: name.contains("vnum"),
|
||||
uses_ffr,
|
||||
is_prf,
|
||||
replicate_width,
|
||||
gather_bases_type,
|
||||
gather_offset_type,
|
||||
gather_index_type,
|
||||
gather_offsets_type,
|
||||
gather_indices_type,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
lazy_static! {
|
||||
static ref PREAMBLE: String = format!(
|
||||
r#"#![allow(unused)]
|
||||
|
||||
use super::*;
|
||||
use std::boxed::Box;
|
||||
use std::convert::{{TryFrom, TryInto}};
|
||||
use std::sync::LazyLock;
|
||||
use std::vec::Vec;
|
||||
use stdarch_test::simd_test;
|
||||
|
||||
static F32_DATA: LazyLock<[f32; {LEN_F32} * {NUM_VECS}]> = LazyLock::new(|| {{
|
||||
(0..{LEN_F32} * {NUM_VECS})
|
||||
.map(|i| i as f32)
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.expect("f32 data incorrectly initialised")
|
||||
}});
|
||||
static F64_DATA: LazyLock<[f64; {LEN_F64} * {NUM_VECS}]> = LazyLock::new(|| {{
|
||||
(0..{LEN_F64} * {NUM_VECS})
|
||||
.map(|i| i as f64)
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.expect("f64 data incorrectly initialised")
|
||||
}});
|
||||
static I8_DATA: LazyLock<[i8; {LEN_I8} * {NUM_VECS}]> = LazyLock::new(|| {{
|
||||
(0..{LEN_I8} * {NUM_VECS})
|
||||
.map(|i| ((i + 128) % 256 - 128) as i8)
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.expect("i8 data incorrectly initialised")
|
||||
}});
|
||||
static I16_DATA: LazyLock<[i16; {LEN_I16} * {NUM_VECS}]> = LazyLock::new(|| {{
|
||||
(0..{LEN_I16} * {NUM_VECS})
|
||||
.map(|i| i as i16)
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.expect("i16 data incorrectly initialised")
|
||||
}});
|
||||
static I32_DATA: LazyLock<[i32; {LEN_I32} * {NUM_VECS}]> = LazyLock::new(|| {{
|
||||
(0..{LEN_I32} * {NUM_VECS})
|
||||
.map(|i| i as i32)
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.expect("i32 data incorrectly initialised")
|
||||
}});
|
||||
static I64_DATA: LazyLock<[i64; {LEN_I64} * {NUM_VECS}]> = LazyLock::new(|| {{
|
||||
(0..{LEN_I64} * {NUM_VECS})
|
||||
.map(|i| i as i64)
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.expect("i64 data incorrectly initialised")
|
||||
}});
|
||||
static U8_DATA: LazyLock<[u8; {LEN_U8} * {NUM_VECS}]> = LazyLock::new(|| {{
|
||||
(0..{LEN_U8} * {NUM_VECS})
|
||||
.map(|i| i as u8)
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.expect("u8 data incorrectly initialised")
|
||||
}});
|
||||
static U16_DATA: LazyLock<[u16; {LEN_U16} * {NUM_VECS}]> = LazyLock::new(|| {{
|
||||
(0..{LEN_U16} * {NUM_VECS})
|
||||
.map(|i| i as u16)
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.expect("u16 data incorrectly initialised")
|
||||
}});
|
||||
static U32_DATA: LazyLock<[u32; {LEN_U32} * {NUM_VECS}]> = LazyLock::new(|| {{
|
||||
(0..{LEN_U32} * {NUM_VECS})
|
||||
.map(|i| i as u32)
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.expect("u32 data incorrectly initialised")
|
||||
}});
|
||||
static U64_DATA: LazyLock<[u64; {LEN_U64} * {NUM_VECS}]> = LazyLock::new(|| {{
|
||||
(0..{LEN_U64} * {NUM_VECS})
|
||||
.map(|i| i as u64)
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.expect("u64 data incorrectly initialised")
|
||||
}});
|
||||
|
||||
#[target_feature(enable = "sve")]
|
||||
fn assert_vector_matches_f32(vector: svfloat32_t, expected: svfloat32_t) {{
|
||||
let defined = svrdffr();
|
||||
assert!(svptest_first(svptrue_b32(), defined));
|
||||
let cmp = svcmpne_f32(defined, vector, expected);
|
||||
assert!(!svptest_any(defined, cmp))
|
||||
}}
|
||||
|
||||
#[target_feature(enable = "sve")]
|
||||
fn assert_vector_matches_f64(vector: svfloat64_t, expected: svfloat64_t) {{
|
||||
let defined = svrdffr();
|
||||
assert!(svptest_first(svptrue_b64(), defined));
|
||||
let cmp = svcmpne_f64(defined, vector, expected);
|
||||
assert!(!svptest_any(defined, cmp))
|
||||
}}
|
||||
|
||||
#[target_feature(enable = "sve")]
|
||||
fn assert_vector_matches_i8(vector: svint8_t, expected: svint8_t) {{
|
||||
let defined = svrdffr();
|
||||
assert!(svptest_first(svptrue_b8(), defined));
|
||||
let cmp = svcmpne_s8(defined, vector, expected);
|
||||
assert!(!svptest_any(defined, cmp))
|
||||
}}
|
||||
|
||||
#[target_feature(enable = "sve")]
|
||||
fn assert_vector_matches_i16(vector: svint16_t, expected: svint16_t) {{
|
||||
let defined = svrdffr();
|
||||
assert!(svptest_first(svptrue_b16(), defined));
|
||||
let cmp = svcmpne_s16(defined, vector, expected);
|
||||
assert!(!svptest_any(defined, cmp))
|
||||
}}
|
||||
|
||||
#[target_feature(enable = "sve")]
|
||||
fn assert_vector_matches_i32(vector: svint32_t, expected: svint32_t) {{
|
||||
let defined = svrdffr();
|
||||
assert!(svptest_first(svptrue_b32(), defined));
|
||||
let cmp = svcmpne_s32(defined, vector, expected);
|
||||
assert!(!svptest_any(defined, cmp))
|
||||
}}
|
||||
|
||||
#[target_feature(enable = "sve")]
|
||||
fn assert_vector_matches_i64(vector: svint64_t, expected: svint64_t) {{
|
||||
let defined = svrdffr();
|
||||
assert!(svptest_first(svptrue_b64(), defined));
|
||||
let cmp = svcmpne_s64(defined, vector, expected);
|
||||
assert!(!svptest_any(defined, cmp))
|
||||
}}
|
||||
|
||||
#[target_feature(enable = "sve")]
|
||||
fn assert_vector_matches_u8(vector: svuint8_t, expected: svuint8_t) {{
|
||||
let defined = svrdffr();
|
||||
assert!(svptest_first(svptrue_b8(), defined));
|
||||
let cmp = svcmpne_u8(defined, vector, expected);
|
||||
assert!(!svptest_any(defined, cmp))
|
||||
}}
|
||||
|
||||
#[target_feature(enable = "sve")]
|
||||
fn assert_vector_matches_u16(vector: svuint16_t, expected: svuint16_t) {{
|
||||
let defined = svrdffr();
|
||||
assert!(svptest_first(svptrue_b16(), defined));
|
||||
let cmp = svcmpne_u16(defined, vector, expected);
|
||||
assert!(!svptest_any(defined, cmp))
|
||||
}}
|
||||
|
||||
#[target_feature(enable = "sve")]
|
||||
fn assert_vector_matches_u32(vector: svuint32_t, expected: svuint32_t) {{
|
||||
let defined = svrdffr();
|
||||
assert!(svptest_first(svptrue_b32(), defined));
|
||||
let cmp = svcmpne_u32(defined, vector, expected);
|
||||
assert!(!svptest_any(defined, cmp))
|
||||
}}
|
||||
|
||||
#[target_feature(enable = "sve")]
|
||||
fn assert_vector_matches_u64(vector: svuint64_t, expected: svuint64_t) {{
|
||||
let defined = svrdffr();
|
||||
assert!(svptest_first(svptrue_b64(), defined));
|
||||
let cmp = svcmpne_u64(defined, vector, expected);
|
||||
assert!(!svptest_any(defined, cmp))
|
||||
}}
|
||||
"#
|
||||
);
|
||||
}
|
||||
|
||||
lazy_static! {
|
||||
static ref MANUAL_TESTS: String = format!(
|
||||
"#[simd_test(enable = \"sve\")]
|
||||
unsafe fn test_ffr() {{
|
||||
svsetffr();
|
||||
let ffr = svrdffr();
|
||||
assert_vector_matches_u8(svdup_n_u8_z(ffr, 1), svindex_u8(1, 0));
|
||||
let pred = svdupq_n_b8(true, false, true, false, true, false, true, false,
|
||||
true, false, true, false, true, false, true, false);
|
||||
svwrffr(pred);
|
||||
let ffr = svrdffr_z(svptrue_b8());
|
||||
assert_vector_matches_u8(svdup_n_u8_z(ffr, 1), svdup_n_u8_z(pred, 1));
|
||||
}}
|
||||
"
|
||||
);
|
||||
}
|
||||
273
library/stdarch/crates/stdarch-gen2/src/main.rs
Normal file
273
library/stdarch/crates/stdarch-gen2/src/main.rs
Normal file
|
|
@ -0,0 +1,273 @@
|
|||
#![feature(pattern)]
|
||||
|
||||
mod assert_instr;
|
||||
mod context;
|
||||
mod expression;
|
||||
mod input;
|
||||
mod intrinsic;
|
||||
mod load_store_tests;
|
||||
mod matching;
|
||||
mod predicate_forms;
|
||||
mod typekinds;
|
||||
mod wildcards;
|
||||
mod wildstring;
|
||||
|
||||
use intrinsic::Test;
|
||||
use itertools::Itertools;
|
||||
use quote::quote;
|
||||
use std::fs::File;
|
||||
use std::io::Write;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::{Command, Stdio};
|
||||
use walkdir::WalkDir;
|
||||
|
||||
fn main() -> Result<(), String> {
|
||||
parse_args()
|
||||
.into_iter()
|
||||
.map(|(filepath, out)| {
|
||||
File::open(&filepath)
|
||||
.map(|f| (f, filepath, out))
|
||||
.map_err(|e| format!("could not read input file: {e}"))
|
||||
})
|
||||
.map(|res| {
|
||||
let (file, filepath, out) = res?;
|
||||
serde_yaml::from_reader(file)
|
||||
.map(|input: input::GeneratorInput| (input, filepath, out))
|
||||
.map_err(|e| format!("could not parse input file: {e}"))
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.map(|(input, filepath, out)| {
|
||||
let intrinsics = input.intrinsics.into_iter()
|
||||
.map(|intrinsic| intrinsic.generate_variants(&input.ctx))
|
||||
.try_collect()
|
||||
.map(|mut vv: Vec<_>| {
|
||||
vv.sort_by_cached_key(|variants| {
|
||||
variants.first().map_or_else(String::default, |variant| {
|
||||
variant.signature.fn_name().to_string()
|
||||
})
|
||||
});
|
||||
vv.into_iter().flatten().collect_vec()
|
||||
})?;
|
||||
|
||||
let loads = intrinsics.iter()
|
||||
.filter_map(|i| {
|
||||
if matches!(i.test, Test::Load(..)) {
|
||||
Some(i.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}).collect();
|
||||
let stores = intrinsics.iter()
|
||||
.filter_map(|i| {
|
||||
if matches!(i.test, Test::Store(..)) {
|
||||
Some(i.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}).collect();
|
||||
load_store_tests::generate_load_store_tests(loads, stores, out.as_ref().map(|o| make_tests_filepath(&filepath, o)).as_ref())?;
|
||||
Ok((
|
||||
input::GeneratorInput {
|
||||
intrinsics,
|
||||
ctx: input.ctx,
|
||||
},
|
||||
filepath,
|
||||
out,
|
||||
))
|
||||
})
|
||||
.try_for_each(
|
||||
|result: context::Result<(input::GeneratorInput, PathBuf, Option<PathBuf>)>| -> context::Result {
|
||||
let (generated, filepath, out) = result?;
|
||||
|
||||
let w = match out {
|
||||
Some(out) => Box::new(
|
||||
File::create(make_output_filepath(&filepath, &out))
|
||||
.map_err(|e| format!("could not create output file: {e}"))?,
|
||||
) as Box<dyn Write>,
|
||||
None => Box::new(std::io::stdout()) as Box<dyn Write>,
|
||||
};
|
||||
|
||||
generate_file(generated, w)
|
||||
.map_err(|e| format!("could not generate output file: {e}"))
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
fn parse_args() -> Vec<(PathBuf, Option<PathBuf>)> {
|
||||
let mut args_it = std::env::args().skip(1);
|
||||
assert!(
|
||||
1 <= args_it.len() && args_it.len() <= 2,
|
||||
"Usage: cargo run -p stdarch-gen2 -- INPUT_DIR [OUTPUT_DIR]"
|
||||
);
|
||||
|
||||
let in_path = Path::new(args_it.next().unwrap().as_str()).to_path_buf();
|
||||
assert!(
|
||||
in_path.exists() && in_path.is_dir(),
|
||||
"invalid path {in_path:#?} given"
|
||||
);
|
||||
|
||||
let out_dir = if let Some(dir) = args_it.next() {
|
||||
let out_path = Path::new(dir.as_str()).to_path_buf();
|
||||
assert!(
|
||||
out_path.exists() && out_path.is_dir(),
|
||||
"invalid path {out_path:#?} given"
|
||||
);
|
||||
Some(out_path)
|
||||
} else {
|
||||
std::env::current_exe()
|
||||
.map(|mut f| {
|
||||
f.pop();
|
||||
f.push("../../crates/core_arch/src/aarch64/");
|
||||
f.exists().then_some(f)
|
||||
})
|
||||
.ok()
|
||||
.flatten()
|
||||
};
|
||||
|
||||
WalkDir::new(in_path)
|
||||
.into_iter()
|
||||
.filter_map(Result::ok)
|
||||
.filter(|f| f.file_type().is_file())
|
||||
.map(|f| (f.into_path(), out_dir.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn generate_file(
|
||||
generated_input: input::GeneratorInput,
|
||||
mut out: Box<dyn Write>,
|
||||
) -> std::io::Result<()> {
|
||||
write!(
|
||||
out,
|
||||
r#"// This code is automatically generated. DO NOT MODIFY.
|
||||
//
|
||||
// Instead, modify `crates/stdarch-gen2/spec/` and run the following command to re-generate this file:
|
||||
//
|
||||
// ```
|
||||
// cargo run --bin=stdarch-gen2 -- crates/stdarch-gen2/spec
|
||||
// ```
|
||||
#![allow(improper_ctypes)]
|
||||
|
||||
#[cfg(test)]
|
||||
use stdarch_test::assert_instr;
|
||||
|
||||
use super::*;{uses_neon}
|
||||
|
||||
"#,
|
||||
uses_neon = generated_input
|
||||
.ctx
|
||||
.uses_neon_types
|
||||
.then_some("\nuse crate::core_arch::arch::aarch64::*;")
|
||||
.unwrap_or_default(),
|
||||
)?;
|
||||
let intrinsics = generated_input.intrinsics;
|
||||
format_code(out, quote! { #(#intrinsics)* })?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn format_code(
|
||||
mut output: impl std::io::Write,
|
||||
input: impl std::fmt::Display,
|
||||
) -> std::io::Result<()> {
|
||||
let proc = Command::new("rustfmt")
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.spawn()?;
|
||||
write!(proc.stdin.as_ref().unwrap(), "{input}")?;
|
||||
output.write_all(proc.wait_with_output()?.stdout.as_slice())
|
||||
}
|
||||
|
||||
/// Derive an output file name from an input file and an output directory.
|
||||
///
|
||||
/// The name is formed by:
|
||||
///
|
||||
/// - ... taking in_filepath.file_name() (dropping all directory components),
|
||||
/// - ... dropping a .yml or .yaml extension (if present),
|
||||
/// - ... then dropping a .spec extension (if present).
|
||||
///
|
||||
/// Panics if the resulting name is empty, or if file_name() is not UTF-8.
|
||||
fn make_output_filepath(in_filepath: &Path, out_dirpath: &Path) -> PathBuf {
|
||||
make_filepath(in_filepath, out_dirpath, |name: &str| format!("{name}.rs"))
|
||||
}
|
||||
|
||||
fn make_tests_filepath(in_filepath: &Path, out_dirpath: &Path) -> PathBuf {
|
||||
make_filepath(in_filepath, out_dirpath, |name: &str| {
|
||||
format!("ld_st_tests_{name}.rs")
|
||||
})
|
||||
}
|
||||
|
||||
fn make_filepath<F: FnOnce(&str) -> String>(
|
||||
in_filepath: &Path,
|
||||
out_dirpath: &Path,
|
||||
name_formatter: F,
|
||||
) -> PathBuf {
|
||||
let mut parts = in_filepath.iter();
|
||||
let name = parts
|
||||
.next_back()
|
||||
.and_then(|f| f.to_str())
|
||||
.expect("Inputs must have valid, UTF-8 file_name()");
|
||||
let dir = parts.next_back().unwrap();
|
||||
|
||||
let name = name
|
||||
.trim_end_matches(".yml")
|
||||
.trim_end_matches(".yaml")
|
||||
.trim_end_matches(".spec");
|
||||
assert!(!name.is_empty());
|
||||
|
||||
let mut output = out_dirpath.to_path_buf();
|
||||
output.push(dir);
|
||||
output.push(name_formatter(name));
|
||||
output
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn infer_output_file() {
|
||||
macro_rules! t {
|
||||
($src:expr, $outdir:expr, $dst:expr) => {
|
||||
let src: PathBuf = $src.iter().collect();
|
||||
let outdir: PathBuf = $outdir.iter().collect();
|
||||
let dst: PathBuf = $dst.iter().collect();
|
||||
assert_eq!(make_output_filepath(&src, &outdir), dst);
|
||||
};
|
||||
}
|
||||
// Documented usage.
|
||||
t!(["x", "NAME.spec.yml"], [""], ["x", "NAME.rs"]);
|
||||
t!(
|
||||
["x", "NAME.spec.yml"],
|
||||
["a", "b"],
|
||||
["a", "b", "x", "NAME.rs"]
|
||||
);
|
||||
t!(
|
||||
["x", "y", "NAME.spec.yml"],
|
||||
["out"],
|
||||
["out", "y", "NAME.rs"]
|
||||
);
|
||||
t!(["x", "NAME.spec.yaml"], ["out"], ["out", "x", "NAME.rs"]);
|
||||
t!(["x", "NAME.spec"], ["out"], ["out", "x", "NAME.rs"]);
|
||||
t!(["x", "NAME.yml"], ["out"], ["out", "x", "NAME.rs"]);
|
||||
t!(["x", "NAME.yaml"], ["out"], ["out", "x", "NAME.rs"]);
|
||||
// Unrecognised extensions get treated as part of the stem.
|
||||
t!(
|
||||
["x", "NAME.spac.yml"],
|
||||
["out"],
|
||||
["out", "x", "NAME.spac.rs"]
|
||||
);
|
||||
t!(["x", "NAME.txt"], ["out"], ["out", "x", "NAME.txt.rs"]);
|
||||
// Always take the top-level directory from the input path
|
||||
t!(
|
||||
["x", "y", "z", "NAME.spec.yml"],
|
||||
["out"],
|
||||
["out", "z", "NAME.rs"]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn infer_output_file_no_stem() {
|
||||
make_output_filepath(Path::new(".spec.yml"), Path::new(""));
|
||||
}
|
||||
}
|
||||
170
library/stdarch/crates/stdarch-gen2/src/matching.rs
Normal file
170
library/stdarch/crates/stdarch-gen2/src/matching.rs
Normal file
|
|
@ -0,0 +1,170 @@
|
|||
use proc_macro2::TokenStream;
|
||||
use quote::ToTokens;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
|
||||
use crate::context::{self, LocalContext};
|
||||
use crate::typekinds::{BaseType, BaseTypeKind, TypeKind};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub struct MatchSizeValues<T> {
|
||||
pub default: T,
|
||||
pub byte: Option<T>,
|
||||
pub halfword: Option<T>,
|
||||
pub doubleword: Option<T>,
|
||||
}
|
||||
|
||||
impl<T> MatchSizeValues<T> {
|
||||
pub fn get(&mut self, ty: &TypeKind, ctx: &LocalContext) -> context::Result<&T> {
|
||||
let base_ty = if let Some(w) = ty.wildcard() {
|
||||
ctx.provide_type_wildcard(w)?
|
||||
} else {
|
||||
ty.clone()
|
||||
};
|
||||
|
||||
if let BaseType::Sized(_, bitsize) = base_ty.base_type().unwrap() {
|
||||
match (bitsize, &self.byte, &self.halfword, &self.doubleword) {
|
||||
(64, _, _, Some(v)) | (16, _, Some(v), _) | (8, Some(v), _, _) => Ok(v),
|
||||
_ => Ok(&self.default),
|
||||
}
|
||||
} else {
|
||||
Err(format!("cannot match bitsize to unsized type {ty:?}!"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub struct MatchKindValues<T> {
|
||||
pub default: T,
|
||||
pub float: Option<T>,
|
||||
pub unsigned: Option<T>,
|
||||
}
|
||||
|
||||
impl<T> MatchKindValues<T> {
|
||||
pub fn get(&mut self, ty: &TypeKind, ctx: &LocalContext) -> context::Result<&T> {
|
||||
let base_ty = if let Some(w) = ty.wildcard() {
|
||||
ctx.provide_type_wildcard(w)?
|
||||
} else {
|
||||
ty.clone()
|
||||
};
|
||||
|
||||
match (
|
||||
base_ty.base_type().unwrap().kind(),
|
||||
&self.float,
|
||||
&self.unsigned,
|
||||
) {
|
||||
(BaseTypeKind::Float, Some(v), _) | (BaseTypeKind::UInt, _, Some(v)) => Ok(v),
|
||||
_ => Ok(&self.default),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged, deny_unknown_fields)]
|
||||
pub enum SizeMatchable<T> {
|
||||
Matched(T),
|
||||
Unmatched {
|
||||
match_size: Option<TypeKind>,
|
||||
#[serde(flatten)]
|
||||
values: MatchSizeValues<Box<T>>,
|
||||
},
|
||||
}
|
||||
|
||||
impl<T: Clone> SizeMatchable<T> {
|
||||
pub fn perform_match(&mut self, ctx: &LocalContext) -> context::Result {
|
||||
match self {
|
||||
Self::Unmatched {
|
||||
match_size: None,
|
||||
values: MatchSizeValues { default, .. },
|
||||
} => *self = Self::Matched(*default.to_owned()),
|
||||
Self::Unmatched {
|
||||
match_size: Some(ty),
|
||||
values,
|
||||
} => *self = Self::Matched(*values.get(ty, ctx)?.to_owned()),
|
||||
_ => {}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: fmt::Debug> AsRef<T> for SizeMatchable<T> {
|
||||
fn as_ref(&self) -> &T {
|
||||
if let SizeMatchable::Matched(v) = self {
|
||||
v
|
||||
} else {
|
||||
panic!("no match for {self:?} was performed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: fmt::Debug> AsMut<T> for SizeMatchable<T> {
|
||||
fn as_mut(&mut self) -> &mut T {
|
||||
if let SizeMatchable::Matched(v) = self {
|
||||
v
|
||||
} else {
|
||||
panic!("no match for {self:?} was performed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: fmt::Debug + ToTokens> ToTokens for SizeMatchable<T> {
|
||||
fn to_tokens(&self, tokens: &mut TokenStream) {
|
||||
self.as_ref().to_tokens(tokens)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged, deny_unknown_fields)]
|
||||
pub enum KindMatchable<T> {
|
||||
Matched(T),
|
||||
Unmatched {
|
||||
match_kind: Option<TypeKind>,
|
||||
#[serde(flatten)]
|
||||
values: MatchKindValues<Box<T>>,
|
||||
},
|
||||
}
|
||||
|
||||
impl<T: Clone> KindMatchable<T> {
|
||||
pub fn perform_match(&mut self, ctx: &LocalContext) -> context::Result {
|
||||
match self {
|
||||
Self::Unmatched {
|
||||
match_kind: None,
|
||||
values: MatchKindValues { default, .. },
|
||||
} => *self = Self::Matched(*default.to_owned()),
|
||||
Self::Unmatched {
|
||||
match_kind: Some(ty),
|
||||
values,
|
||||
} => *self = Self::Matched(*values.get(ty, ctx)?.to_owned()),
|
||||
_ => {}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: fmt::Debug> AsRef<T> for KindMatchable<T> {
|
||||
fn as_ref(&self) -> &T {
|
||||
if let KindMatchable::Matched(v) = self {
|
||||
v
|
||||
} else {
|
||||
panic!("no match for {self:?} was performed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: fmt::Debug> AsMut<T> for KindMatchable<T> {
|
||||
fn as_mut(&mut self) -> &mut T {
|
||||
if let KindMatchable::Matched(v) = self {
|
||||
v
|
||||
} else {
|
||||
panic!("no match for {self:?} was performed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: fmt::Debug + ToTokens> ToTokens for KindMatchable<T> {
|
||||
fn to_tokens(&self, tokens: &mut TokenStream) {
|
||||
self.as_ref().to_tokens(tokens)
|
||||
}
|
||||
}
|
||||
249
library/stdarch/crates/stdarch-gen2/src/predicate_forms.rs
Normal file
249
library/stdarch/crates/stdarch-gen2/src/predicate_forms.rs
Normal file
|
|
@ -0,0 +1,249 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{DeserializeFromStr, SerializeDisplay};
|
||||
use std::fmt;
|
||||
use std::str::FromStr;
|
||||
|
||||
use crate::context;
|
||||
use crate::expression::{Expression, FnCall, IdentifierType};
|
||||
use crate::intrinsic::Intrinsic;
|
||||
use crate::typekinds::{ToRepr, TypeKind};
|
||||
use crate::wildcards::Wildcard;
|
||||
use crate::wildstring::WildString;
|
||||
|
||||
const ZEROING_SUFFIX: &str = "_z";
|
||||
const MERGING_SUFFIX: &str = "_m";
|
||||
const DONT_CARE_SUFFIX: &str = "_x";
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ZeroingMethod {
|
||||
/// Drop the specified argument and replace it with a zeroinitializer
|
||||
Drop { drop: WildString },
|
||||
/// Apply zero selection to the specified variable when zeroing
|
||||
Select { select: WildString },
|
||||
}
|
||||
|
||||
impl PartialOrd for ZeroingMethod {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
impl Ord for ZeroingMethod {
|
||||
fn cmp(&self, _: &Self) -> std::cmp::Ordering {
|
||||
std::cmp::Ordering::Equal
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub enum DontCareMethod {
|
||||
#[default]
|
||||
Inferred,
|
||||
AsZeroing,
|
||||
AsMerging,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq, Deserialize, Serialize)]
|
||||
pub struct PredicationMethods {
|
||||
/// Zeroing method, if the zeroing predicate form is used
|
||||
#[serde(default)]
|
||||
pub zeroing_method: Option<ZeroingMethod>,
|
||||
/// Don't care method, if the don't care predicate form is used
|
||||
#[serde(default)]
|
||||
pub dont_care_method: DontCareMethod,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub enum PredicateForm {
|
||||
/// Enables merging predicate form
|
||||
Merging,
|
||||
/// Enables "don't care" predicate form.
|
||||
DontCare(DontCareMethod),
|
||||
/// Enables zeroing predicate form. If LLVM zeroselection is performed, then
|
||||
/// set the `select` field to the variable that gets set. Otherwise set the
|
||||
/// `drop` field if the zeroinitializer replaces a predicate when merging.
|
||||
Zeroing(ZeroingMethod),
|
||||
}
|
||||
|
||||
impl PredicateForm {
|
||||
pub fn get_suffix(&self) -> &'static str {
|
||||
match self {
|
||||
PredicateForm::Zeroing { .. } => ZEROING_SUFFIX,
|
||||
PredicateForm::Merging => MERGING_SUFFIX,
|
||||
PredicateForm::DontCare { .. } => DONT_CARE_SUFFIX,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn make_zeroinitializer(ty: &TypeKind) -> Expression {
|
||||
FnCall::new_expression(
|
||||
format!("svdup_n_{}", ty.acle_notation_repr())
|
||||
.parse()
|
||||
.unwrap(),
|
||||
vec![if ty.base_type().unwrap().is_float() {
|
||||
Expression::FloatConstant(0.0)
|
||||
} else {
|
||||
Expression::IntConstant(0)
|
||||
}],
|
||||
)
|
||||
}
|
||||
|
||||
pub fn make_zeroselector(pg_var: WildString, op_var: WildString, ty: &TypeKind) -> Expression {
|
||||
FnCall::new_expression(
|
||||
format!("svsel_{}", ty.acle_notation_repr())
|
||||
.parse()
|
||||
.unwrap(),
|
||||
vec![
|
||||
Expression::Identifier(pg_var, IdentifierType::Variable),
|
||||
Expression::Identifier(op_var, IdentifierType::Variable),
|
||||
Self::make_zeroinitializer(ty),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
pub fn post_build(&self, intrinsic: &mut Intrinsic) -> context::Result {
|
||||
// Drop the argument
|
||||
match self {
|
||||
PredicateForm::Zeroing(ZeroingMethod::Drop { drop: drop_var }) => {
|
||||
intrinsic.signature.drop_argument(drop_var)?
|
||||
}
|
||||
PredicateForm::DontCare(DontCareMethod::AsZeroing) => {
|
||||
if let ZeroingMethod::Drop { drop } = intrinsic
|
||||
.input
|
||||
.predication_methods
|
||||
.zeroing_method
|
||||
.to_owned()
|
||||
.ok_or_else(|| {
|
||||
"DontCareMethod::AsZeroing without zeroing method.".to_string()
|
||||
})?
|
||||
{
|
||||
intrinsic.signature.drop_argument(&drop)?
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn infer_dont_care(mask: &PredicationMask, methods: &PredicationMethods) -> PredicateForm {
|
||||
let method = if methods.dont_care_method == DontCareMethod::Inferred {
|
||||
if mask.has_zeroing()
|
||||
&& matches!(methods.zeroing_method, Some(ZeroingMethod::Drop { .. }))
|
||||
{
|
||||
DontCareMethod::AsZeroing
|
||||
} else {
|
||||
DontCareMethod::AsMerging
|
||||
}
|
||||
} else {
|
||||
methods.dont_care_method
|
||||
};
|
||||
|
||||
PredicateForm::DontCare(method)
|
||||
}
|
||||
|
||||
pub fn compile_list(
|
||||
mask: &PredicationMask,
|
||||
methods: &PredicationMethods,
|
||||
) -> context::Result<Vec<PredicateForm>> {
|
||||
let mut forms = Vec::new();
|
||||
|
||||
if mask.has_merging() {
|
||||
forms.push(PredicateForm::Merging)
|
||||
}
|
||||
|
||||
if mask.has_dont_care() {
|
||||
forms.push(Self::infer_dont_care(mask, methods))
|
||||
}
|
||||
|
||||
if mask.has_zeroing() {
|
||||
if let Some(method) = methods.zeroing_method.to_owned() {
|
||||
forms.push(PredicateForm::Zeroing(method))
|
||||
} else {
|
||||
return Err(
|
||||
"cannot create a zeroing variant without a zeroing method specified!"
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(forms)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(
|
||||
Debug, Clone, Copy, Default, PartialEq, Eq, Hash, DeserializeFromStr, SerializeDisplay,
|
||||
)]
|
||||
pub struct PredicationMask {
|
||||
/// Merging
|
||||
m: bool,
|
||||
/// Don't care
|
||||
x: bool,
|
||||
/// Zeroing
|
||||
z: bool,
|
||||
}
|
||||
|
||||
impl PredicationMask {
|
||||
pub fn has_merging(&self) -> bool {
|
||||
self.m
|
||||
}
|
||||
|
||||
pub fn has_dont_care(&self) -> bool {
|
||||
self.x
|
||||
}
|
||||
|
||||
pub fn has_zeroing(&self) -> bool {
|
||||
self.z
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for PredicationMask {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
let mut result = Self::default();
|
||||
for kind in s.bytes() {
|
||||
match kind {
|
||||
b'm' => result.m = true,
|
||||
b'x' => result.x = true,
|
||||
b'z' => result.z = true,
|
||||
_ => {
|
||||
return Err(format!(
|
||||
"unknown predicate form modifier: {}",
|
||||
char::from(kind)
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if result.m || result.x || result.z {
|
||||
Ok(result)
|
||||
} else {
|
||||
Err("invalid predication mask".to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for PredicationMask {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
self.m.then(|| write!(f, "m")).transpose()?;
|
||||
self.x.then(|| write!(f, "x")).transpose()?;
|
||||
self.z.then(|| write!(f, "z")).transpose().map(|_| ())
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&WildString> for PredicationMask {
|
||||
type Error = String;
|
||||
|
||||
fn try_from(value: &WildString) -> Result<Self, Self::Error> {
|
||||
value
|
||||
.wildcards()
|
||||
.find_map(|w| {
|
||||
if let Wildcard::PredicateForms(mask) = w {
|
||||
Some(*mask)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.ok_or_else(|| "no predicate forms were specified in the name".to_string())
|
||||
}
|
||||
}
|
||||
1024
library/stdarch/crates/stdarch-gen2/src/typekinds.rs
Normal file
1024
library/stdarch/crates/stdarch-gen2/src/typekinds.rs
Normal file
File diff suppressed because it is too large
Load diff
179
library/stdarch/crates/stdarch-gen2/src/wildcards.rs
Normal file
179
library/stdarch/crates/stdarch-gen2/src/wildcards.rs
Normal file
|
|
@ -0,0 +1,179 @@
|
|||
use lazy_static::lazy_static;
|
||||
use regex::Regex;
|
||||
use serde_with::{DeserializeFromStr, SerializeDisplay};
|
||||
use std::fmt;
|
||||
use std::str::FromStr;
|
||||
|
||||
use crate::{
|
||||
predicate_forms::PredicationMask,
|
||||
typekinds::{ToRepr, TypeKind, TypeKindOptions, VectorTupleSize},
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, SerializeDisplay, DeserializeFromStr)]
|
||||
pub enum Wildcard {
|
||||
Type(Option<usize>),
|
||||
/// NEON type derivated by a base type
|
||||
NEONType(Option<usize>, Option<VectorTupleSize>),
|
||||
/// SVE type derivated by a base type
|
||||
SVEType(Option<usize>, Option<VectorTupleSize>),
|
||||
/// Integer representation of bitsize
|
||||
Size(Option<usize>),
|
||||
/// Integer representation of bitsize minus one
|
||||
SizeMinusOne(Option<usize>),
|
||||
/// Literal representation of the bitsize: b(yte), h(half), w(ord) or d(ouble)
|
||||
SizeLiteral(Option<usize>),
|
||||
/// Literal representation of the type kind: f(loat), s(igned), u(nsigned)
|
||||
TypeKind(Option<usize>, Option<TypeKindOptions>),
|
||||
/// Log2 of the size in bytes
|
||||
SizeInBytesLog2(Option<usize>),
|
||||
/// Predicate to be inferred from the specified type
|
||||
Predicate(Option<usize>),
|
||||
/// Predicate to be inferred from the greatest type
|
||||
MaxPredicate,
|
||||
|
||||
Scale(Box<Wildcard>, Box<TypeKind>),
|
||||
|
||||
// Other wildcards
|
||||
LLVMLink,
|
||||
NVariant,
|
||||
/// Predicate forms to use and placeholder for a predicate form function name modifier
|
||||
PredicateForms(PredicationMask),
|
||||
|
||||
/// User-set wildcard through `substitutions`
|
||||
Custom(String),
|
||||
}
|
||||
|
||||
impl Wildcard {
|
||||
pub fn is_nonpredicate_type(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
Wildcard::Type(..) | Wildcard::NEONType(..) | Wildcard::SVEType(..)
|
||||
)
|
||||
}
|
||||
|
||||
pub fn get_typeset_index(&self) -> Option<usize> {
|
||||
match self {
|
||||
Wildcard::Type(idx) | Wildcard::NEONType(idx, ..) | Wildcard::SVEType(idx, ..) => {
|
||||
Some(idx.unwrap_or(0))
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for Wildcard {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
lazy_static! {
|
||||
static ref RE: Regex = Regex::new(r"^(?P<wildcard>\w+?)(?:_x(?P<tuple_size>[2-4]))?(?:\[(?P<index>\d+)\])?(?:\.(?P<modifiers>\w+))?(?:\s+as\s+(?P<scale_to>.*?))?$").unwrap();
|
||||
}
|
||||
|
||||
if let Some(c) = RE.captures(s) {
|
||||
let wildcard_name = &c["wildcard"];
|
||||
let inputset_index = c
|
||||
.name("index")
|
||||
.map(<&str>::from)
|
||||
.map(usize::from_str)
|
||||
.transpose()
|
||||
.map_err(|_| format!("{:#?} is not a valid type index", &c["index"]))?;
|
||||
let tuple_size = c
|
||||
.name("tuple_size")
|
||||
.map(<&str>::from)
|
||||
.map(VectorTupleSize::from_str)
|
||||
.transpose()
|
||||
.map_err(|_| format!("{:#?} is not a valid tuple size", &c["tuple_size"]))?;
|
||||
let modifiers = c.name("modifiers").map(<&str>::from);
|
||||
|
||||
let wildcard = match (wildcard_name, inputset_index, tuple_size, modifiers) {
|
||||
("type", index, None, None) => Ok(Wildcard::Type(index)),
|
||||
("neon_type", index, tuple, None) => Ok(Wildcard::NEONType(index, tuple)),
|
||||
("sve_type", index, tuple, None) => Ok(Wildcard::SVEType(index, tuple)),
|
||||
("size", index, None, None) => Ok(Wildcard::Size(index)),
|
||||
("size_minus_one", index, None, None) => Ok(Wildcard::SizeMinusOne(index)),
|
||||
("size_literal", index, None, None) => Ok(Wildcard::SizeLiteral(index)),
|
||||
("type_kind", index, None, modifiers) => Ok(Wildcard::TypeKind(
|
||||
index,
|
||||
modifiers.map(|modifiers| modifiers.parse()).transpose()?,
|
||||
)),
|
||||
("size_in_bytes_log2", index, None, None) => Ok(Wildcard::SizeInBytesLog2(index)),
|
||||
("predicate", index, None, None) => Ok(Wildcard::Predicate(index)),
|
||||
("max_predicate", None, None, None) => Ok(Wildcard::MaxPredicate),
|
||||
("llvm_link", None, None, None) => Ok(Wildcard::LLVMLink),
|
||||
("_n", None, None, None) => Ok(Wildcard::NVariant),
|
||||
(w, None, None, None) if w.starts_with('_') => {
|
||||
// test for predicate forms
|
||||
let pf_mask = PredicationMask::from_str(&w[1..]);
|
||||
if let Ok(mask) = pf_mask {
|
||||
if mask.has_merging() {
|
||||
Ok(Wildcard::PredicateForms(mask))
|
||||
} else {
|
||||
Err("cannot add predication without a Merging form".to_string())
|
||||
}
|
||||
} else {
|
||||
Err(format!("invalid wildcard `{s:#?}`"))
|
||||
}
|
||||
}
|
||||
(cw, None, None, None) => Ok(Wildcard::Custom(cw.to_string())),
|
||||
_ => Err(format!("invalid wildcard `{s:#?}`")),
|
||||
}?;
|
||||
|
||||
let scale_to = c
|
||||
.name("scale_to")
|
||||
.map(<&str>::from)
|
||||
.map(TypeKind::from_str)
|
||||
.transpose()
|
||||
.map_err(|_| format!("{:#?} is not a valid type", &c["scale_to"]))?;
|
||||
|
||||
if let Some(scale_to) = scale_to {
|
||||
Ok(Wildcard::Scale(Box::new(wildcard), Box::new(scale_to)))
|
||||
} else {
|
||||
Ok(wildcard)
|
||||
}
|
||||
} else {
|
||||
Err(format!("invalid wildcard `{s:#?}`"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Wildcard {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Type(None) => write!(f, "type"),
|
||||
Self::Type(Some(index)) => write!(f, "type[{index}]"),
|
||||
Self::NEONType(None, None) => write!(f, "neon_type"),
|
||||
Self::NEONType(Some(index), None) => write!(f, "neon_type[{index}]"),
|
||||
Self::NEONType(None, Some(tuple_size)) => write!(f, "neon_type_x{tuple_size}"),
|
||||
Self::NEONType(Some(index), Some(tuple_size)) => {
|
||||
write!(f, "neon_type_x{tuple_size}[{index}]")
|
||||
}
|
||||
Self::SVEType(None, None) => write!(f, "sve_type"),
|
||||
Self::SVEType(Some(index), None) => write!(f, "sve_type[{index}]"),
|
||||
Self::SVEType(None, Some(tuple_size)) => write!(f, "sve_type_x{tuple_size}"),
|
||||
Self::SVEType(Some(index), Some(tuple_size)) => {
|
||||
write!(f, "sve_type_x{tuple_size}[{index}]")
|
||||
}
|
||||
Self::Size(None) => write!(f, "size"),
|
||||
Self::Size(Some(index)) => write!(f, "size[{index}]"),
|
||||
Self::SizeMinusOne(None) => write!(f, "size_minus_one"),
|
||||
Self::SizeMinusOne(Some(index)) => write!(f, "size_minus_one[{index}]"),
|
||||
Self::SizeLiteral(None) => write!(f, "size_literal"),
|
||||
Self::SizeLiteral(Some(index)) => write!(f, "size_literal[{index}]"),
|
||||
Self::TypeKind(None, None) => write!(f, "type_kind"),
|
||||
Self::TypeKind(None, Some(opts)) => write!(f, "type_kind.{opts}"),
|
||||
Self::TypeKind(Some(index), None) => write!(f, "type_kind[{index}]"),
|
||||
Self::TypeKind(Some(index), Some(opts)) => write!(f, "type_kind[{index}].{opts}"),
|
||||
Self::SizeInBytesLog2(None) => write!(f, "size_in_bytes_log2"),
|
||||
Self::SizeInBytesLog2(Some(index)) => write!(f, "size_in_bytes_log2[{index}]"),
|
||||
Self::Predicate(None) => write!(f, "predicate"),
|
||||
Self::Predicate(Some(index)) => write!(f, "predicate[{index}]"),
|
||||
Self::MaxPredicate => write!(f, "max_predicate"),
|
||||
Self::LLVMLink => write!(f, "llvm_link"),
|
||||
Self::NVariant => write!(f, "_n"),
|
||||
Self::PredicateForms(mask) => write!(f, "_{mask}"),
|
||||
|
||||
Self::Scale(wildcard, ty) => write!(f, "{wildcard} as {}", ty.rust_repr()),
|
||||
Self::Custom(cw) => write!(f, "{cw}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
353
library/stdarch/crates/stdarch-gen2/src/wildstring.rs
Normal file
353
library/stdarch/crates/stdarch-gen2/src/wildstring.rs
Normal file
|
|
@ -0,0 +1,353 @@
|
|||
use itertools::Itertools;
|
||||
use proc_macro2::TokenStream;
|
||||
use quote::{quote, ToTokens, TokenStreamExt};
|
||||
use serde_with::{DeserializeFromStr, SerializeDisplay};
|
||||
use std::str::pattern::Pattern;
|
||||
use std::{fmt, str::FromStr};
|
||||
|
||||
use crate::context::LocalContext;
|
||||
use crate::typekinds::{ToRepr, TypeRepr};
|
||||
use crate::wildcards::Wildcard;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum WildStringPart {
|
||||
String(String),
|
||||
Wildcard(Wildcard),
|
||||
}
|
||||
|
||||
/// Wildcard-able string
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Default, SerializeDisplay, DeserializeFromStr)]
|
||||
pub struct WildString(Vec<WildStringPart>);
|
||||
|
||||
impl WildString {
|
||||
pub fn has_wildcards(&self) -> bool {
|
||||
for part in self.0.iter() {
|
||||
if let WildStringPart::Wildcard(..) = part {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
pub fn wildcards(&self) -> impl Iterator<Item = &Wildcard> + '_ {
|
||||
self.0.iter().filter_map(|part| match part {
|
||||
WildStringPart::Wildcard(w) => Some(w),
|
||||
_ => None,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> impl Iterator<Item = &WildStringPart> + '_ {
|
||||
self.0.iter()
|
||||
}
|
||||
|
||||
pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut WildStringPart> + '_ {
|
||||
self.0.iter_mut()
|
||||
}
|
||||
|
||||
pub fn starts_with(&self, s2: &str) -> bool {
|
||||
self.to_string().starts_with(s2)
|
||||
}
|
||||
|
||||
pub fn prepend_str(&mut self, s: impl Into<String>) {
|
||||
self.0.insert(0, WildStringPart::String(s.into()))
|
||||
}
|
||||
|
||||
pub fn push_str(&mut self, s: impl Into<String>) {
|
||||
self.0.push(WildStringPart::String(s.into()))
|
||||
}
|
||||
|
||||
pub fn push_wildcard(&mut self, w: Wildcard) {
|
||||
self.0.push(WildStringPart::Wildcard(w))
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.0.is_empty()
|
||||
}
|
||||
|
||||
pub fn replace<'a, P>(&'a self, from: P, to: &str) -> WildString
|
||||
where
|
||||
P: Pattern<'a> + Copy,
|
||||
{
|
||||
WildString(
|
||||
self.0
|
||||
.iter()
|
||||
.map(|part| match part {
|
||||
WildStringPart::String(s) => WildStringPart::String(s.replace(from, to)),
|
||||
part => part.clone(),
|
||||
})
|
||||
.collect_vec(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn build_acle(&mut self, ctx: &LocalContext) -> Result<(), String> {
|
||||
self.build(ctx, TypeRepr::ACLENotation)
|
||||
}
|
||||
|
||||
pub fn build(&mut self, ctx: &LocalContext, repr: TypeRepr) -> Result<(), String> {
|
||||
self.iter_mut().try_for_each(|wp| -> Result<(), String> {
|
||||
if let WildStringPart::Wildcard(w) = wp {
|
||||
let value = ctx
|
||||
.provide_substitution_wildcard(w)
|
||||
.or_else(|_| ctx.provide_type_wildcard(w).map(|ty| ty.repr(repr)))?;
|
||||
*wp = WildStringPart::String(value);
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for WildString {
|
||||
fn from(s: String) -> Self {
|
||||
WildString(vec![WildStringPart::String(s)])
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for WildString {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
enum State {
|
||||
Normal { start: usize },
|
||||
Wildcard { start: usize, count: usize },
|
||||
EscapeTokenOpen { start: usize, at: usize },
|
||||
EscapeTokenClose { start: usize, at: usize },
|
||||
}
|
||||
|
||||
let mut ws = WildString::default();
|
||||
match s
|
||||
.char_indices()
|
||||
.try_fold(State::Normal { start: 0 }, |state, (idx, ch)| {
|
||||
match (state, ch) {
|
||||
(State::Normal { start }, '{') => Ok(State::EscapeTokenOpen { start, at: idx }),
|
||||
(State::Normal { start }, '}') => {
|
||||
Ok(State::EscapeTokenClose { start, at: idx })
|
||||
}
|
||||
(State::EscapeTokenOpen { start, at }, '{')
|
||||
| (State::EscapeTokenClose { start, at }, '}') => {
|
||||
if start < at {
|
||||
ws.push_str(&s[start..at])
|
||||
}
|
||||
|
||||
Ok(State::Normal { start: idx })
|
||||
}
|
||||
(State::EscapeTokenOpen { at, .. }, '}') => Err(format!(
|
||||
"empty wildcard given in string {s:?} at position {at}"
|
||||
)),
|
||||
(State::EscapeTokenOpen { start, at }, _) => {
|
||||
if start < at {
|
||||
ws.push_str(&s[start..at])
|
||||
}
|
||||
|
||||
Ok(State::Wildcard {
|
||||
start: idx,
|
||||
count: 0,
|
||||
})
|
||||
}
|
||||
(State::EscapeTokenClose { at, .. }, _) => Err(format!(
|
||||
"closing a non-wildcard/bad escape in string {s:?} at position {at}"
|
||||
)),
|
||||
// Nesting wildcards is only supported for `{foo as {bar}}`, wildcards cannot be
|
||||
// nested at the start of a WildString.
|
||||
(State::Wildcard { start, count }, '{') => Ok(State::Wildcard {
|
||||
start,
|
||||
count: count + 1,
|
||||
}),
|
||||
(State::Wildcard { start, count: 0 }, '}') => {
|
||||
ws.push_wildcard(s[start..idx].parse()?);
|
||||
Ok(State::Normal { start: idx + 1 })
|
||||
}
|
||||
(State::Wildcard { start, count }, '}') => Ok(State::Wildcard {
|
||||
start,
|
||||
count: count - 1,
|
||||
}),
|
||||
(state @ State::Normal { .. }, _) | (state @ State::Wildcard { .. }, _) => {
|
||||
Ok(state)
|
||||
}
|
||||
}
|
||||
})? {
|
||||
State::Normal { start } => {
|
||||
if start < s.len() {
|
||||
ws.push_str(&s[start..]);
|
||||
}
|
||||
|
||||
Ok(ws)
|
||||
}
|
||||
State::EscapeTokenOpen { at, .. } | State::Wildcard { start: at, .. } => Err(format!(
|
||||
"unclosed wildcard in string {s:?} at position {at}"
|
||||
)),
|
||||
State::EscapeTokenClose { at, .. } => Err(format!(
|
||||
"closing a non-wildcard/bad escape in string {s:?} at position {at}"
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for WildString {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{}",
|
||||
self.0
|
||||
.iter()
|
||||
.map(|part| match part {
|
||||
WildStringPart::String(s) => s.to_owned(),
|
||||
WildStringPart::Wildcard(w) => format!("{{{w}}}"),
|
||||
})
|
||||
.join("")
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl ToTokens for WildString {
|
||||
fn to_tokens(&self, tokens: &mut TokenStream) {
|
||||
assert!(
|
||||
!self.has_wildcards(),
|
||||
"cannot convert string with wildcards {self:?} to TokenStream"
|
||||
);
|
||||
let str = self.to_string();
|
||||
tokens.append_all(quote! { #str })
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::typekinds::*;
|
||||
use crate::wildstring::*;
|
||||
|
||||
#[test]
|
||||
fn test_empty_string() {
|
||||
let ws: WildString = "".parse().unwrap();
|
||||
assert_eq!(ws.0.len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_plain_string() {
|
||||
let ws: WildString = "plain string".parse().unwrap();
|
||||
assert_eq!(ws.0.len(), 1);
|
||||
assert_eq!(
|
||||
ws,
|
||||
WildString(vec![WildStringPart::String("plain string".to_string())])
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_escaped_curly_brackets() {
|
||||
let ws: WildString = "VALUE = {{value}}".parse().unwrap();
|
||||
assert_eq!(ws.to_string(), "VALUE = {value}");
|
||||
assert!(!ws.has_wildcards());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_escaped_curly_brackets_wildcard() {
|
||||
let ws: WildString = "TYPE = {{{type}}}".parse().unwrap();
|
||||
assert_eq!(ws.to_string(), "TYPE = {{type}}");
|
||||
assert_eq!(ws.0.len(), 4);
|
||||
assert!(ws.has_wildcards());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_right_boundary() {
|
||||
let s = "string test {type}";
|
||||
let ws: WildString = s.parse().unwrap();
|
||||
assert_eq!(&ws.to_string(), s);
|
||||
assert!(ws.has_wildcards());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_left_boundary() {
|
||||
let s = "{type} string test";
|
||||
let ws: WildString = s.parse().unwrap();
|
||||
assert_eq!(&ws.to_string(), s);
|
||||
assert!(ws.has_wildcards());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_recursive_wildcard() {
|
||||
let s = "string test {type[0] as {type[1]}}";
|
||||
let ws: WildString = s.parse().unwrap();
|
||||
|
||||
assert_eq!(ws.0.len(), 2);
|
||||
assert_eq!(
|
||||
ws,
|
||||
WildString(vec![
|
||||
WildStringPart::String("string test ".to_string()),
|
||||
WildStringPart::Wildcard(Wildcard::Scale(
|
||||
Box::new(Wildcard::Type(Some(0))),
|
||||
Box::new(TypeKind::Wildcard(Wildcard::Type(Some(1)))),
|
||||
))
|
||||
])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scale_wildcard() {
|
||||
let s = "string {type[0] as i8} test";
|
||||
let ws: WildString = s.parse().unwrap();
|
||||
|
||||
assert_eq!(ws.0.len(), 3);
|
||||
assert_eq!(
|
||||
ws,
|
||||
WildString(vec![
|
||||
WildStringPart::String("string ".to_string()),
|
||||
WildStringPart::Wildcard(Wildcard::Scale(
|
||||
Box::new(Wildcard::Type(Some(0))),
|
||||
Box::new(TypeKind::Base(BaseType::Sized(BaseTypeKind::Int, 8))),
|
||||
)),
|
||||
WildStringPart::String(" test".to_string())
|
||||
])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_solitaire_wildcard() {
|
||||
let ws: WildString = "{type}".parse().unwrap();
|
||||
assert_eq!(ws.0.len(), 1);
|
||||
assert_eq!(
|
||||
ws,
|
||||
WildString(vec![WildStringPart::Wildcard(Wildcard::Type(None))])
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_wildcard() {
|
||||
"string {}"
|
||||
.parse::<WildString>()
|
||||
.expect_err("expected parse error");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_open_wildcard_right() {
|
||||
"string {"
|
||||
.parse::<WildString>()
|
||||
.expect_err("expected parse error");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_close_wildcard_right() {
|
||||
"string }"
|
||||
.parse::<WildString>()
|
||||
.expect_err("expected parse error");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_open_wildcard_left() {
|
||||
"{string"
|
||||
.parse::<WildString>()
|
||||
.expect_err("expected parse error");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_close_wildcard_left() {
|
||||
"}string"
|
||||
.parse::<WildString>()
|
||||
.expect_err("expected parse error");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_consecutive_wildcards() {
|
||||
let s = "svprf{size_literal[1]}_gather_{type[0]}{index_or_offset}";
|
||||
let ws: WildString = s.parse().unwrap();
|
||||
assert_eq!(ws.to_string(), s)
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue