Generator for SVE intrinsics.

Co-authored-by: Jamie Cunliffe <Jamie.Cunliffe@arm.com>
Co-authored-by: Jacob Bramley <jacob.bramley@arm.com>
Co-authored-by: Luca Vizzarro <Luca.Vizzarro@arm.com>
Co-authored-by: Adam Gemmell <adam.gemmell@arm.com>
This commit is contained in:
Luca Vizzarro 2023-10-25 14:04:37 +01:00 committed by Amanieu d'Antras
parent 9e24b307df
commit 03e4f2636e
14 changed files with 6197 additions and 0 deletions

View file

@ -3,6 +3,18 @@ resolver = "1"
members = [
"crates/*",
"examples"
"crates/stdarch-verify",
"crates/core_arch",
"crates/std_detect",
"crates/stdarch-gen-arm",
"crates/stdarch-gen-loongarch",
"crates/stdarch-gen",
"crates/stdarch-gen2",
"crates/intrinsic-test",
"examples/"
]
exclude = [
"crates/wasm-assert-instr-tests"
]
[profile.release]

View file

@ -0,0 +1,22 @@
[package]
name = "stdarch-gen2"
version = "0.1.0"
authors = ["Luca Vizzarro <luca.vizzarro@arm.com>",
"Jamie Cunliffe <Jamie.Cunliffe@arm.com>",
"Adam Gemmell <Adam.Gemmell@arm.com",
"Jacob Bramley <jacob.bramley@arm.com>"]
license = "MIT OR Apache-2.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
itertools = "0.10"
lazy_static = "1.4.0"
proc-macro2 = "1.0"
quote = "1.0"
regex = "1.5"
serde = { version = "1.0", features = ["derive"] }
serde_with = "1.14"
serde_yaml = "0.8"
walkdir = "2.3.2"

View file

@ -0,0 +1,372 @@
use proc_macro2::TokenStream;
use quote::{format_ident, quote, ToTokens, TokenStreamExt};
use serde::de::{self, MapAccess, Visitor};
use serde::{ser::SerializeSeq, Deserialize, Deserializer, Serialize};
use std::fmt;
use crate::{
context::{self, Context},
typekinds::{BaseType, BaseTypeKind},
wildstring::WildString,
};
/// An assertion on the instruction emitted for an intrinsic.
///
/// Deserialized untagged: a bare string yields `Basic`, a two-element
/// form yields `WithArgs` (mnemonic plus an arguments expression).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum InstructionAssertion {
    /// Instruction mnemonic only (may contain wildcards).
    Basic(WildString),
    /// Instruction mnemonic plus an arguments expression to assert on.
    WithArgs(WildString, WildString),
}
impl InstructionAssertion {
    /// Resolves any wildcards in the assertion's strings against the local context.
    fn build(&mut self, ctx: &Context) -> context::Result {
        match self {
            InstructionAssertion::Basic(instr) => instr.build_acle(ctx.local),
            InstructionAssertion::WithArgs(instr, args) => {
                // Build the mnemonic first; short-circuit before the args on error.
                instr.build_acle(ctx.local)?;
                args.build_acle(ctx.local)
            }
        }
    }
}
impl ToTokens for InstructionAssertion {
    /// Emits `MNEMONIC` or `MNEMONIC, ARGS` for use inside `assert_instr(..)`.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let mnemonic = match self {
            Self::Basic(instr) | Self::WithArgs(instr, _) => instr,
        };
        let instr = format_ident!("{}", mnemonic.to_string());
        tokens.append_all(quote! { #instr });

        if let Self::WithArgs(_, args) = self {
            // The arguments must parse as a Rust expression stream.
            let ex: TokenStream = args
                .to_string()
                .parse()
                .expect("invalid instruction assertion arguments expression given");
            tokens.append_all(quote! {, #ex})
        }
    }
}
/// Asserts that the given instruction is present for the intrinsic of the associated type bitsize.
///
/// `#[serde(remote = "Self")]` generates a derived (de)serializer which the
/// hand-written `Serialize`/`Deserialize` impls below delegate to for the
/// full map form, while also accepting string/sequence shorthands.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(remote = "Self")]
pub struct InstructionAssertionMethodForBitsize {
    // Assertion used when no bitsize-specific override matches.
    pub default: InstructionAssertion,
    // Override for 8-bit element types.
    pub byte: Option<InstructionAssertion>,
    // Override for 16-bit element types.
    pub halfword: Option<InstructionAssertion>,
    // Override for 32-bit element types.
    pub word: Option<InstructionAssertion>,
    // Override for 64-bit element types.
    pub doubleword: Option<InstructionAssertion>,
}
impl InstructionAssertionMethodForBitsize {
fn build(&mut self, ctx: &Context) -> context::Result {
if let Some(ref mut byte) = self.byte {
byte.build(ctx)?
}
if let Some(ref mut halfword) = self.halfword {
halfword.build(ctx)?
}
if let Some(ref mut word) = self.word {
word.build(ctx)?
}
if let Some(ref mut doubleword) = self.doubleword {
doubleword.build(ctx)?
}
self.default.build(ctx)
}
}
impl Serialize for InstructionAssertionMethodForBitsize {
    /// Serializes back into the same shorthands accepted on input: a bare
    /// string for a `Basic` default, a two-element sequence for `WithArgs`,
    /// and the full map form whenever any bitsize override is set.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        match self {
            // No overrides, plain mnemonic: emit a single string.
            InstructionAssertionMethodForBitsize {
                default: InstructionAssertion::Basic(instr),
                byte: None,
                halfword: None,
                word: None,
                doubleword: None,
            } => serializer.serialize_str(&instr.to_string()),
            // No overrides, mnemonic with arguments: emit `[instr, args]`.
            InstructionAssertionMethodForBitsize {
                default: InstructionAssertion::WithArgs(instr, args),
                byte: None,
                halfword: None,
                word: None,
                doubleword: None,
            } => {
                let mut seq = serializer.serialize_seq(Some(2))?;
                seq.serialize_element(&instr.to_string())?;
                seq.serialize_element(&args.to_string())?;
                seq.end()
            }
            // Any overrides present: delegate to the derived `remote = "Self"`
            // serializer, which produces the full map form.
            _ => InstructionAssertionMethodForBitsize::serialize(self, serializer),
        }
    }
}
impl<'de> Deserialize<'de> for InstructionAssertionMethodForBitsize {
    /// Accepts three input shapes: a bare string (`Basic` default), a
    /// two-element sequence (`WithArgs` default), or a map matching the
    /// struct fields (handled by the derived `remote = "Self"` impl).
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct IAMVisitor;

        impl<'de> Visitor<'de> for IAMVisitor {
            type Value = InstructionAssertionMethodForBitsize;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("array, string or map")
            }

            // String form: the whole string is the default instruction.
            fn visit_str<E>(self, value: &str) -> Result<InstructionAssertionMethodForBitsize, E>
            where
                E: de::Error,
            {
                Ok(InstructionAssertionMethodForBitsize {
                    default: InstructionAssertion::Basic(value.parse().map_err(E::custom)?),
                    byte: None,
                    halfword: None,
                    word: None,
                    doubleword: None,
                })
            }

            // Sequence form: exactly `[instruction, args]`.
            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: de::SeqAccess<'de>,
            {
                use serde::de::Error;
                let make_err =
                    || Error::custom("invalid number of arguments passed to assert_instruction");
                let instruction = seq.next_element()?.ok_or_else(make_err)?;
                let args = seq.next_element()?.ok_or_else(make_err)?;
                // NOTE(review): `size_hint` may return `None` for some
                // formats, in which case surplus elements would pass
                // undetected here — confirm against the YAML inputs used.
                if let Some(true) = seq.size_hint().map(|len| len > 0) {
                    Err(make_err())
                } else {
                    Ok(InstructionAssertionMethodForBitsize {
                        default: InstructionAssertion::WithArgs(instruction, args),
                        byte: None,
                        halfword: None,
                        word: None,
                        doubleword: None,
                    })
                }
            }

            // Map form: delegate to the derived `remote = "Self"` deserializer.
            fn visit_map<M>(self, map: M) -> Result<InstructionAssertionMethodForBitsize, M::Error>
            where
                M: MapAccess<'de>,
            {
                InstructionAssertionMethodForBitsize::deserialize(
                    de::value::MapAccessDeserializer::new(map),
                )
            }
        }

        deserializer.deserialize_any(IAMVisitor)
    }
}
/// Asserts that the given instruction is present for the intrinsic of the associated type.
///
/// Like [`InstructionAssertionMethodForBitsize`], this uses
/// `#[serde(remote = "Self")]` so the custom impls below can fall back to
/// the derived map (de)serializer.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(remote = "Self")]
pub struct InstructionAssertionMethod {
    /// Instruction for integer intrinsics
    pub default: InstructionAssertionMethodForBitsize,
    /// Instruction for floating-point intrinsics (optional)
    #[serde(default)]
    pub float: Option<InstructionAssertionMethodForBitsize>,
    /// Instruction for unsigned integer intrinsics (optional)
    #[serde(default)]
    pub unsigned: Option<InstructionAssertionMethodForBitsize>,
}
impl InstructionAssertionMethod {
pub(crate) fn build(&mut self, ctx: &Context) -> context::Result {
if let Some(ref mut float) = self.float {
float.build(ctx)?
}
if let Some(ref mut unsigned) = self.unsigned {
unsigned.build(ctx)?
}
self.default.build(ctx)
}
}
impl Serialize for InstructionAssertionMethod {
    /// Serializes with the same shorthands as deserialization: a bare string
    /// when only a `Basic` default is set, `[instr, args]` when only a
    /// `WithArgs` default is set, otherwise the full derived map form.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        match self {
            // Only a plain default mnemonic: emit one string.
            InstructionAssertionMethod {
                default:
                    InstructionAssertionMethodForBitsize {
                        default: InstructionAssertion::Basic(instr),
                        byte: None,
                        halfword: None,
                        word: None,
                        doubleword: None,
                    },
                float: None,
                unsigned: None,
            } => serializer.serialize_str(&instr.to_string()),
            // Only a default mnemonic with arguments: emit `[instr, args]`.
            InstructionAssertionMethod {
                default:
                    InstructionAssertionMethodForBitsize {
                        default: InstructionAssertion::WithArgs(instr, args),
                        byte: None,
                        halfword: None,
                        word: None,
                        doubleword: None,
                    },
                float: None,
                unsigned: None,
            } => {
                let mut seq = serializer.serialize_seq(Some(2))?;
                seq.serialize_element(&instr.to_string())?;
                seq.serialize_element(&args.to_string())?;
                seq.end()
            }
            // Anything richer: delegate to the derived `remote = "Self"` impl.
            _ => InstructionAssertionMethod::serialize(self, serializer),
        }
    }
}
impl<'de> Deserialize<'de> for InstructionAssertionMethod {
    /// Accepts a bare string, a two-element `[instruction, args]` sequence,
    /// or a full map (delegated to the derived `remote = "Self"` impl).
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct IAMVisitor;

        impl<'de> Visitor<'de> for IAMVisitor {
            type Value = InstructionAssertionMethod;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("array, string or map")
            }

            // String form: the string becomes the integer-default mnemonic.
            fn visit_str<E>(self, value: &str) -> Result<InstructionAssertionMethod, E>
            where
                E: de::Error,
            {
                Ok(InstructionAssertionMethod {
                    default: InstructionAssertionMethodForBitsize {
                        default: InstructionAssertion::Basic(value.parse().map_err(E::custom)?),
                        byte: None,
                        halfword: None,
                        word: None,
                        doubleword: None,
                    },
                    float: None,
                    unsigned: None,
                })
            }

            // Sequence form: exactly `[instruction, args]`.
            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: de::SeqAccess<'de>,
            {
                use serde::de::Error;
                let make_err =
                    || Error::custom("invalid number of arguments passed to assert_instruction");
                let instruction = seq.next_element()?.ok_or_else(make_err)?;
                let args = seq.next_element()?.ok_or_else(make_err)?;
                // NOTE(review): as in the bitsize variant, a `None` size hint
                // would let surplus elements pass undetected.
                if let Some(true) = seq.size_hint().map(|len| len > 0) {
                    Err(make_err())
                } else {
                    Ok(InstructionAssertionMethod {
                        default: InstructionAssertionMethodForBitsize {
                            default: InstructionAssertion::WithArgs(instruction, args),
                            byte: None,
                            halfword: None,
                            word: None,
                            doubleword: None,
                        },
                        float: None,
                        unsigned: None,
                    })
                }
            }

            // Map form: delegate to the derived deserializer.
            fn visit_map<M>(self, map: M) -> Result<InstructionAssertionMethod, M::Error>
            where
                M: MapAccess<'de>,
            {
                InstructionAssertionMethod::deserialize(de::value::MapAccessDeserializer::new(map))
            }
        }

        deserializer.deserialize_any(IAMVisitor)
    }
}
/// Pairs a list of instruction assertions with the (optional) base type of
/// the intrinsic they apply to; `ToTokens` selects the matching variant.
// NOTE(review): `&Vec<_>` as a field type could be `&[_]` — kept as-is since
// changing the field type would affect constructors elsewhere.
#[derive(Debug)]
pub struct InstructionAssertionsForBaseType<'a>(
    pub &'a Vec<InstructionAssertionMethod>,
    pub &'a Option<&'a BaseType>,
);
impl<'a> ToTokens for InstructionAssertionsForBaseType<'a> {
    /// Emits one `#[cfg_attr(test, assert_instr(..))]` attribute per
    /// assertion method, picking first by the base type's kind
    /// (float/unsigned/default) and then by its bitsize
    /// (byte/halfword/word/doubleword/default).
    fn to_tokens(&self, tokens: &mut TokenStream) {
        self.0.iter().for_each(
            |InstructionAssertionMethod {
                 default,
                 float,
                 unsigned,
             }| {
                // Stage 1: select by type kind. Having float/unsigned
                // overrides but no known base type is a generator bug.
                let kind = self.1.map(|ty| ty.kind());
                let instruction = match (kind, float, unsigned) {
                    (None, float, unsigned) if float.is_some() || unsigned.is_some() => {
                        unreachable!(
                            "cannot determine the base type kind for instruction assertion: {self:#?}")
                    }
                    (Some(BaseTypeKind::Float), Some(float), _) => float,
                    (Some(BaseTypeKind::UInt), _, Some(unsigned)) => unsigned,
                    _ => default,
                };
                // Stage 2: select by element bitsize, falling back to the
                // method's default assertion when no override is present.
                let bitsize = self.1.and_then(|ty| ty.get_size().ok());
                let instruction = match (bitsize, instruction) {
                    (
                        Some(8),
                        InstructionAssertionMethodForBitsize {
                            byte: Some(byte), ..
                        },
                    ) => byte,
                    (
                        Some(16),
                        InstructionAssertionMethodForBitsize {
                            halfword: Some(halfword),
                            ..
                        },
                    ) => halfword,
                    (
                        Some(32),
                        InstructionAssertionMethodForBitsize {
                            word: Some(word), ..
                        },
                    ) => word,
                    (
                        Some(64),
                        InstructionAssertionMethodForBitsize {
                            doubleword: Some(doubleword),
                            ..
                        },
                    ) => doubleword,
                    (_, InstructionAssertionMethodForBitsize { default, .. }) => default,
                };
                tokens.append_all(quote! { #[cfg_attr(test, assert_instr(#instruction))]})
            },
        );
    }
}

View file

@ -0,0 +1,249 @@
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::{
expression::Expression,
input::{InputSet, InputType},
intrinsic::{Constraint, Intrinsic, Signature},
matching::SizeMatchable,
predicate_forms::PredicateForm,
typekinds::{ToRepr, TypeKind},
wildcards::Wildcard,
wildstring::WildString,
};
/// Maximum SVE vector size
const SVE_VECTOR_MAX_SIZE: u32 = 2048;
/// Vector register size
const VECTOR_REG_SIZE: u32 = 128;
/// Generator result
pub type Result<T = ()> = std::result::Result<T, String>;
/// Per-architecture configuration loaded from the generator's input files.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureSettings {
    // Architecture name; accepted under the shorter `arch` key in input files.
    #[serde(alias = "arch")]
    pub arch_name: String,
    // Target features required by intrinsics for this architecture.
    pub target_feature: Vec<String>,
    // Prefix for generated LLVM intrinsic link names; aliased as `llvm_prefix`.
    #[serde(alias = "llvm_prefix")]
    pub llvm_link_prefix: String,
}
/// Generator-wide context shared across all intrinsic groups.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GlobalContext {
    // All architecture configurations in play for this run.
    pub arch_cfgs: Vec<ArchitectureSettings>,
    // Whether the generated output uses NEON types (defaults to false).
    #[serde(default)]
    pub uses_neon_types: bool,
}
/// Context of an intrinsic group
#[derive(Debug, Clone, Default)]
pub struct GroupContext {
    /// LLVM links to target input sets
    pub links: HashMap<String, InputSet>,
}
/// Where a tracked variable comes from within a generated intrinsic body.
#[derive(Debug, Clone, Copy)]
pub enum VariableType {
    // Declared as a parameter of the intrinsic's signature.
    Argument,
    // Introduced inside the body (e.g. by a `Let` expression).
    Internal,
}
/// Context for a single concrete instantiation of an intrinsic.
#[derive(Debug, Clone)]
pub struct LocalContext {
    // Working copy of the intrinsic's signature (wildcards get substituted).
    pub signature: Signature,
    // The input set (types, predicate forms, ...) this instantiation is built from.
    pub input: InputSet,
    // Wildcard -> rendered-string substitutions accumulated during the build.
    pub substitutions: HashMap<Wildcard, String>,
    // Tracked variables: name -> (type, where it was introduced).
    pub variables: HashMap<String, (TypeKind, VariableType)>,
}
impl LocalContext {
    /// Creates a context for one instantiation of `original`, cloning its
    /// signature so wildcard substitution can mutate it freely.
    pub fn new(input: InputSet, original: &Intrinsic) -> LocalContext {
        LocalContext {
            signature: original.signature.clone(),
            input,
            substitutions: HashMap::new(),
            variables: HashMap::new(),
        }
    }

    /// Resolves a type-valued wildcard against this instantiation's input set.
    ///
    /// Errors if the wildcard is not a type-producing wildcard or no type is
    /// available at the requested index.
    pub fn provide_type_wildcard(&self, wildcard: &Wildcard) -> Result<TypeKind> {
        let err = || format!("wildcard {{{wildcard}}} not found");
        // Closure factories that wrap a base type into a fixed-width (NEON)
        // or scalable (SVE) vector with the given tuple size.
        let make_neon = |tuple_size| move |ty| TypeKind::make_vector(ty, false, tuple_size);
        let make_sve = |tuple_size| move |ty| TypeKind::make_vector(ty, true, tuple_size);
        match wildcard {
            Wildcard::Type(idx) => self.input.typekind(*idx).ok_or_else(err),
            Wildcard::NEONType(idx, tuple_size) => self
                .input
                .typekind(*idx)
                .ok_or_else(err)
                .and_then(make_neon(*tuple_size)),
            Wildcard::SVEType(idx, tuple_size) => self
                .input
                .typekind(*idx)
                .ok_or_else(err)
                .and_then(make_sve(*tuple_size)),
            // Predicate derived from the type at `idx` (or the sole type).
            Wildcard::Predicate(idx) => self.input.typekind(*idx).map_or_else(
                || {
                    if idx.is_none() && self.input.types_len() == 1 {
                        Err(err())
                    } else {
                        Err(format!(
                            "there is no type at index {} to infer the predicate from",
                            idx.unwrap_or(0)
                        ))
                    }
                },
                |ref ty| TypeKind::make_predicate_from(ty),
            ),
            // Predicate derived from the widest element type in the input set.
            Wildcard::MaxPredicate => self
                .input
                .iter()
                .filter_map(|arg| arg.typekind())
                .max_by(|x, y| {
                    x.base_type()
                        .and_then(|bt| bt.get_size().ok())
                        .unwrap_or(0)
                        .cmp(&y.base_type().and_then(|bt| bt.get_size().ok()).unwrap_or(0))
                })
                .map_or_else(
                    || Err("there are no types available to infer the predicate from".to_string()),
                    TypeKind::make_predicate_from,
                ),
            // Resolve the inner wildcard, then retag its vector's base type
            // with `as_ty` (which may itself be a wildcard).
            Wildcard::Scale(w, as_ty) => {
                let mut ty = self.provide_type_wildcard(w)?;
                if let Some(vty) = ty.vector_mut() {
                    let base_ty = if let Some(w) = as_ty.wildcard() {
                        *self.provide_type_wildcard(w)?.base_type().unwrap()
                    } else {
                        *as_ty.base_type().unwrap()
                    };
                    vty.cast_base_type_as(base_ty)
                }
                Ok(ty)
            }
            _ => Err(err()),
        }
    }

    /// Renders a wildcard to the string substituted into names and bodies.
    ///
    /// Non-type wildcards are computed here; anything else falls through to
    /// the `substitutions` map populated earlier in the build.
    pub fn provide_substitution_wildcard(&self, wildcard: &Wildcard) -> Result<String> {
        let err = || Err(format!("wildcard {{{wildcard}}} not found"));
        match wildcard {
            Wildcard::SizeLiteral(idx) => self.input.typekind(*idx)
                .map_or_else(err, |ty| Ok(ty.size_literal())),
            Wildcard::Size(idx) => self.input.typekind(*idx)
                .map_or_else(err, |ty| Ok(ty.size())),
            Wildcard::SizeMinusOne(idx) => self.input.typekind(*idx)
                .map_or_else(err, |ty| Ok((ty.size().parse::<i32>().unwrap()-1).to_string())),
            Wildcard::SizeInBytesLog2(idx) => self.input.typekind(*idx)
                .map_or_else(err, |ty| Ok(ty.size_in_bytes_log2())),
            // An unsubstituted N-variant wildcard renders as nothing.
            Wildcard::NVariant if self.substitutions.get(wildcard).is_none() => Ok(String::new()),
            // Type-kind literal, optionally filtered to a set of kinds
            // (non-matching kinds render as the empty string).
            Wildcard::TypeKind(idx, opts) => {
                self.input.typekind(*idx)
                    .map_or_else(err, |ty| {
                        let literal = if let Some(opts) = opts {
                            opts.contains(ty.base_type().map(|bt| *bt.kind()).ok_or_else(|| {
                                format!("cannot retrieve a type literal out of {ty}")
                            })?)
                            .then(|| ty.type_kind())
                            .unwrap_or_default()
                        } else {
                            ty.type_kind()
                        };
                        Ok(literal)
                    })
            }
            // Predicate-form suffix; its presence is guaranteed by an earlier
            // compilation step, hence the unreachable! on absence.
            Wildcard::PredicateForms(_) => self
                .input
                .iter()
                .find_map(|arg| {
                    if let InputType::PredicateForm(pf) = arg {
                        Some(pf.get_suffix().to_string())
                    } else {
                        None
                    }
                })
                .ok_or_else(|| unreachable!("attempting to render a predicate form wildcard, but no predicate form was compiled for it")),
            _ => self
                .substitutions
                .get(wildcard)
                .map_or_else(err, |s| Ok(s.clone())),
        }
    }

    /// Translates a const-generic constraint into the `static_assert*` macro
    /// call that enforces it in the generated intrinsic.
    pub fn make_assertion_from_constraint(&self, constraint: &Constraint) -> Result<Expression> {
        match constraint {
            // Allowed-values constraint: `v == a || v == b || ...`.
            Constraint::AnyI32 {
                variable,
                any_values,
            } => {
                let where_ex = any_values
                    .iter()
                    .map(|value| format!("{variable} == {value}"))
                    .join(" || ");
                Ok(Expression::MacroCall("static_assert".to_string(), where_ex))
            }
            // Inclusive-range constraint (range must already be size-matched).
            Constraint::RangeI32 {
                variable,
                range: SizeMatchable::Matched(range),
            } => Ok(Expression::MacroCall(
                "static_assert_range".to_string(),
                format!(
                    "{variable}, {min}, {max}",
                    min = range.start(),
                    max = range.end()
                ),
            )),
            // Element-count constraints: upper bound is derived from the
            // maximum vector size divided by the element bitsize.
            Constraint::SVEMaxElems {
                variable,
                sve_max_elems_type: ty,
            }
            | Constraint::VecMaxElems {
                variable,
                vec_max_elems_type: ty,
            } => {
                if !self.input.is_empty() {
                    let higher_limit = match constraint {
                        Constraint::SVEMaxElems { .. } => SVE_VECTOR_MAX_SIZE,
                        Constraint::VecMaxElems { .. } => VECTOR_REG_SIZE,
                        _ => unreachable!(),
                    };
                    let max = ty.base_type()
                        .map(|ty| ty.get_size())
                        .transpose()?
                        .map_or_else(
                            || Err(format!("can't make an assertion out of constraint {self:?}: no valid type is present")),
                            |bitsize| Ok(higher_limit / bitsize - 1))?;
                    Ok(Expression::MacroCall(
                        "static_assert_range".to_string(),
                        format!("{variable}, 0, {max}"),
                    ))
                } else {
                    Err(format!("can't make an assertion out of constraint {self:?}: no types are being used"))
                }
            }
            _ => unreachable!("constraints were not built successfully!"),
        }
    }

    /// Returns the predicate form of this instantiation, if any.
    pub fn predicate_form(&self) -> Option<&PredicateForm> {
        self.input.iter().find_map(|arg| arg.predicate_form())
    }

    /// Returns the N-variant operand of this instantiation, if any.
    pub fn n_variant_op(&self) -> Option<&WildString> {
        self.input.iter().find_map(|arg| arg.n_variant_op())
    }
}
/// Bundles the three context layers passed through the build pipeline:
/// per-instantiation (`local`), per-group (`group`) and run-wide (`global`).
pub struct Context<'ctx> {
    pub local: &'ctx mut LocalContext,
    pub group: &'ctx mut GroupContext,
    pub global: &'ctx GlobalContext,
}

View file

@ -0,0 +1,546 @@
use itertools::Itertools;
use lazy_static::lazy_static;
use proc_macro2::{Literal, TokenStream};
use quote::{format_ident, quote, ToTokens, TokenStreamExt};
use regex::Regex;
use serde::de::{self, MapAccess, Visitor};
use serde::{Deserialize, Deserializer, Serialize};
use std::fmt;
use std::str::FromStr;
use crate::intrinsic::Intrinsic;
use crate::{
context::{self, Context, VariableType},
intrinsic::{Argument, LLVMLink, StaticDefinition},
matching::{MatchKindValues, MatchSizeValues},
typekinds::{BaseType, BaseTypeKind, TypeKind},
wildcards::Wildcard,
wildstring::WildString,
};
/// Distinguishes how an identifier expression is treated during the build.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum IdentifierType {
    // A `$`-prefixed scope variable: tracked and validated against the
    // local context's variable table.
    Variable,
    // A plain symbol (function name, type name, ...): not tracked.
    Symbol,
}
/// The two shapes of a `Let` expression: with or without an explicit type.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum LetVariant {
    // `let name = expr;`
    Basic(WildString, Box<Expression>),
    // `let name: ty = expr;`
    WithType(WildString, TypeKind, Box<Expression>),
}
/// A function call expression: `ptr::<turbofish>(args)`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FnCall(
    /// Function pointer
    pub Box<Expression>,
    /// Function arguments
    pub Vec<Expression>,
    /// Function turbofish arguments
    #[serde(default)]
    pub Vec<Expression>,
);
impl FnCall {
    /// Convenience constructor: wraps `fn_ptr` and `arguments` into an
    /// `Expression::FnCall` with no turbofish arguments.
    pub fn new_expression(fn_ptr: Expression, arguments: Vec<Expression>) -> Expression {
        FnCall(Box::new(fn_ptr), arguments, Vec::new()).into()
    }

    /// Returns `true` if this call's target is the symbol `llvm_link_name`.
    ///
    /// Takes `&str` instead of `&String` (idiomatic, and backward-compatible:
    /// existing `&String` call sites coerce), and compares the rendered name
    /// directly rather than via a reference to a temporary `String`.
    pub fn is_llvm_link_call(&self, llvm_link_name: &str) -> bool {
        if let Expression::Identifier(fn_name, IdentifierType::Symbol) = self.0.as_ref() {
            fn_name.to_string() == llvm_link_name
        } else {
            false
        }
    }

    /// Pre-build pass: recurses into the function pointer, then every
    /// positional and turbofish argument, short-circuiting on error.
    pub fn pre_build(&mut self, ctx: &mut Context) -> context::Result {
        self.0.pre_build(ctx)?;
        self.1
            .iter_mut()
            .chain(self.2.iter_mut())
            .try_for_each(|ex| ex.pre_build(ctx))
    }

    /// Build pass: same traversal order as `pre_build`.
    pub fn build(&mut self, intrinsic: &Intrinsic, ctx: &mut Context) -> context::Result {
        self.0.build(intrinsic, ctx)?;
        self.1
            .iter_mut()
            .chain(self.2.iter_mut())
            .try_for_each(|ex| ex.build(intrinsic, ctx))
    }
}
impl ToTokens for FnCall {
    /// Renders `ptr::<G1, G2>(a, b)`; the turbofish segment is only emitted
    /// when turbofish arguments are present.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        self.0.to_tokens(tokens);
        let generics = &self.2;
        if !generics.is_empty() {
            tokens.append_all(quote! {::<#(#generics),*>});
        }
        let args = &self.1;
        tokens.append_all(quote! { (#(#args),*) })
    }
}
/// The expression language the generator's YAML `compose` sections are
/// written in; rendered to Rust via `ToTokens`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(remote = "Self", deny_unknown_fields)]
pub enum Expression {
    /// (Re)Defines a variable
    Let(LetVariant),
    /// Performs a variable assignment operation
    Assign(String, Box<Expression>),
    /// Performs a macro call
    MacroCall(String, String),
    /// Performs a function call
    FnCall(FnCall),
    /// Performs a method call. The following:
    /// `MethodCall: ["$object", "to_string", []]`
    /// is tokenized as:
    /// `object.to_string()`.
    MethodCall(Box<Expression>, String, Vec<Expression>),
    /// Symbol identifier name, prepend with a `$` to treat it as a scope variable
    /// which engages variable tracking and enables inference.
    /// E.g. `my_function_name` for a generic symbol or `$my_variable` for
    /// a variable.
    Identifier(WildString, IdentifierType),
    /// Constant signed integer number expression
    IntConstant(i32),
    /// Constant floating point number expression
    FloatConstant(f32),
    /// Constant boolean expression, either `true` or `false`
    BoolConstant(bool),
    /// Array expression
    Array(Vec<Expression>),

    // complex expressions
    /// Makes an LLVM link.
    ///
    /// It stores the link's function name in the wildcard `{llvm_link}`, for use in
    /// subsequent expressions.
    LLVMLink(LLVMLink),
    /// Casts the given expression to the specified (unchecked) type
    CastAs(Box<Expression>, String),
    /// Returns the LLVM `undef` symbol
    SvUndef,
    /// Multiplication
    Multiply(Box<Expression>, Box<Expression>),
    /// Converts the specified constant to the specified type's kind
    ConvertConst(TypeKind, i32),
    /// Yields the given type in the Rust representation
    Type(TypeKind),
    /// Selects an expression by the size of the given type; resolved away
    /// during `pre_build` and never tokenized directly.
    MatchSize(TypeKind, MatchSizeValues<Box<Expression>>),
    /// Selects an expression by the kind of the given type; resolved away
    /// during `pre_build` and never tokenized directly.
    MatchKind(TypeKind, MatchKindValues<Box<Expression>>),
}
impl Expression {
    /// First build pass: resolves `MatchSize`/`MatchKind` selections into
    /// concrete expressions and recurses into nested expressions, before the
    /// main `build` pass runs.
    pub fn pre_build(&mut self, ctx: &mut Context) -> context::Result {
        match self {
            Self::FnCall(fn_call) => fn_call.pre_build(ctx),
            Self::MethodCall(cl_ptr_ex, _, arg_exs) => {
                cl_ptr_ex.pre_build(ctx)?;
                arg_exs.iter_mut().try_for_each(|ex| ex.pre_build(ctx))
            }
            Self::Let(LetVariant::Basic(_, ex) | LetVariant::WithType(_, _, ex)) => {
                ex.pre_build(ctx)
            }
            Self::CastAs(ex, _) => ex.pre_build(ctx),
            Self::Multiply(lhs, rhs) => {
                lhs.pre_build(ctx)?;
                rhs.pre_build(ctx)
            }
            // Replace the match expression with the selected branch, then
            // pre-build the replacement in place.
            Self::MatchSize(match_ty, values) => {
                *self = *values.get(match_ty, ctx.local)?.to_owned();
                self.pre_build(ctx)
            }
            Self::MatchKind(match_ty, values) => {
                *self = *values.get(match_ty, ctx.local)?.to_owned();
                self.pre_build(ctx)
            }
            _ => Ok(()),
        }
    }

    /// Main build pass: resolves wildcards, validates variable references,
    /// registers `let`-bound variables, and rewrites LLVM-link calls and
    /// const conversions into their final form.
    pub fn build(&mut self, intrinsic: &Intrinsic, ctx: &mut Context) -> context::Result {
        match self {
            Self::LLVMLink(link) => link.build_and_save(ctx),
            // Scope variables must already be tracked in the local context.
            Self::Identifier(identifier, id_type) => {
                identifier.build_acle(ctx.local)?;
                if let IdentifierType::Variable = id_type {
                    ctx.local
                        .variables
                        .get(&identifier.to_string())
                        .map(|_| ())
                        .ok_or_else(|| format!("invalid variable {identifier} being referenced"))
                } else {
                    Ok(())
                }
            }
            // Calls targeting the current LLVM link get argument/return
            // conversions applied and replace this expression wholesale.
            Self::FnCall(fn_call) => {
                fn_call.build(intrinsic, ctx)?;

                if let Some(llvm_link_name) = ctx.local.substitutions.get(&Wildcard::LLVMLink) {
                    if fn_call.is_llvm_link_call(llvm_link_name) {
                        *self = intrinsic
                            .llvm_link()
                            .expect("got LLVMLink wildcard without a LLVM link in `compose`")
                            .apply_conversions_to_call(fn_call.clone(), ctx.local)?
                    }
                }

                Ok(())
            }
            Self::MethodCall(cl_ptr_ex, _, arg_exs) => {
                cl_ptr_ex.build(intrinsic, ctx)?;
                arg_exs
                    .iter_mut()
                    .try_for_each(|ex| ex.build(intrinsic, ctx))
            }
            // Register the bound variable (with its declared type if any,
            // otherwise a placeholder) before building the initializer.
            Self::Let(variant) => {
                let (var_name, ex, ty) = match variant {
                    LetVariant::Basic(var_name, ex) => (var_name, ex, None),
                    LetVariant::WithType(var_name, ty, ex) => {
                        if let Some(w) = ty.wildcard() {
                            ty.populate_wildcard(ctx.local.provide_type_wildcard(w)?)?;
                        }
                        (var_name, ex, Some(ty.to_owned()))
                    }
                };

                var_name.build_acle(ctx.local)?;
                ctx.local.variables.insert(
                    var_name.to_string(),
                    (
                        ty.unwrap_or_else(|| TypeKind::Custom("unknown".to_string())),
                        VariableType::Internal,
                    ),
                );
                ex.build(intrinsic, ctx)
            }
            Self::CastAs(ex, _) => ex.build(intrinsic, ctx),
            Self::Multiply(lhs, rhs) => {
                lhs.build(intrinsic, ctx)?;
                rhs.build(intrinsic, ctx)
            }
            // The constant becomes a float or int literal depending on the
            // (wildcard-resolved) target type's base kind.
            Self::ConvertConst(ty, num) => {
                if let Some(w) = ty.wildcard() {
                    *ty = ctx.local.provide_type_wildcard(w)?
                }
                if let Some(BaseType::Sized(BaseTypeKind::Float, _)) = ty.base() {
                    *self = Expression::FloatConstant(*num as f32)
                } else {
                    *self = Expression::IntConstant(*num)
                }
                Ok(())
            }
            Self::Type(ty) => {
                if let Some(w) = ty.wildcard() {
                    *ty = ctx.local.provide_type_wildcard(w)?
                }
                Ok(())
            }
            _ => Ok(()),
        }
    }

    /// True if the expression requires an `unsafe` context in a safe function.
    ///
    /// The classification is somewhat fuzzy, based on actual usage (e.g. empirical function names)
    /// rather than a full parse. This is a reasonable approach because mistakes here will usually
    /// be caught at build time:
    ///
    ///   - Missing an `unsafe` is a build error.
    ///   - An unnecessary `unsafe` is a warning, made into an error by the CI's `-D warnings`.
    ///
    /// This **panics** if it encounters an expression that shouldn't appear in a safe function at
    /// all (such as `SvUndef`).
    pub fn requires_unsafe_wrapper(&self, ctx_fn: &str) -> bool {
        match self {
            // The call will need to be unsafe, but the declaration does not.
            Self::LLVMLink(..) => false,
            // Identifiers, literals and type names are never unsafe.
            Self::Identifier(..) => false,
            Self::IntConstant(..) => false,
            Self::FloatConstant(..) => false,
            Self::BoolConstant(..) => false,
            Self::Type(..) => false,
            Self::ConvertConst(..) => false,
            // Nested structures that aren't inherently unsafe, but could contain other expressions
            // that might be.
            Self::Assign(_var, exp) => exp.requires_unsafe_wrapper(ctx_fn),
            Self::Let(LetVariant::Basic(_, exp) | LetVariant::WithType(_, _, exp)) => {
                exp.requires_unsafe_wrapper(ctx_fn)
            }
            Self::Array(exps) => exps.iter().any(|exp| exp.requires_unsafe_wrapper(ctx_fn)),
            Self::Multiply(lhs, rhs) => {
                lhs.requires_unsafe_wrapper(ctx_fn) || rhs.requires_unsafe_wrapper(ctx_fn)
            }
            Self::CastAs(exp, _ty) => exp.requires_unsafe_wrapper(ctx_fn),
            // Functions and macros can be unsafe, but can also contain other expressions.
            Self::FnCall(FnCall(fn_exp, args, turbo_args)) => {
                let fn_name = fn_exp.to_string();
                fn_exp.requires_unsafe_wrapper(ctx_fn)
                    || fn_name.starts_with("_sv")
                    || fn_name.starts_with("simd_")
                    || fn_name.ends_with("transmute")
                    || args.iter().any(|exp| exp.requires_unsafe_wrapper(ctx_fn))
                    || turbo_args
                        .iter()
                        .any(|exp| exp.requires_unsafe_wrapper(ctx_fn))
            }
            Self::MethodCall(exp, fn_name, args) => match fn_name.as_str() {
                // `as_signed` and `as_unsigned` are unsafe because they're trait methods with
                // target features to allow use on feature-dependent types (such as SVE vectors).
                // We can safely wrap them here.
                "as_signed" => true,
                "as_unsigned" => true,
                _ => {
                    exp.requires_unsafe_wrapper(ctx_fn)
                        || args.iter().any(|exp| exp.requires_unsafe_wrapper(ctx_fn))
                }
            },
            // We only use macros to check const generics (using static assertions).
            Self::MacroCall(_name, _args) => false,
            // Materialising uninitialised values is always unsafe, and we avoid it in safe
            // functions.
            Self::SvUndef => panic!("Refusing to wrap unsafe SvUndef in safe function '{ctx_fn}'."),
            // Variants that aren't tokenised. We shouldn't encounter these here.
            Self::MatchKind(..) => {
                unimplemented!("The unsafety of {self:?} cannot be determined in '{ctx_fn}'.")
            }
            Self::MatchSize(..) => {
                unimplemented!("The unsafety of {self:?} cannot be determined in '{ctx_fn}'.")
            }
        }
    }
}
impl FromStr for Expression {
    type Err = String;

    /// Parses the string shorthand for expressions: the literal `SvUndef`,
    /// a macro call `name!(args);`, or an identifier (a leading `$` marks a
    /// tracked scope variable).
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        lazy_static! {
            // Matches `name!(args)` with an optional trailing semicolon.
            static ref MACRO_RE: Regex =
                Regex::new(r"^(?P<name>[\w\d_]+)!\((?P<ex>.*?)\);?$").unwrap();
        }

        if s == "SvUndef" {
            Ok(Expression::SvUndef)
        } else if let Some(c) = MACRO_RE.captures(s) {
            // A single `captures` call replaces the previous `is_match` +
            // `captures` pair, which ran the regex twice over the input.
            let ex = c["ex"].to_string();
            // Validate that the macro arguments parse as Rust tokens.
            let _: TokenStream = ex
                .parse()
                .map_err(|e| format!("could not parse macro call expression: {e:#?}"))?;
            Ok(Expression::MacroCall(c["name"].to_string(), ex))
        } else {
            let (s, id_type) = if let Some(varname) = s.strip_prefix('$') {
                (varname, IdentifierType::Variable)
            } else {
                (s, IdentifierType::Symbol)
            };
            let identifier = s.trim().parse()?;
            Ok(Expression::Identifier(identifier, id_type))
        }
    }
}
// Wraps a function call directly as an expression.
impl From<FnCall> for Expression {
    fn from(fn_call: FnCall) -> Self {
        Expression::FnCall(fn_call)
    }
}

// A bare wild string converts to a plain (untracked) symbol identifier.
impl From<WildString> for Expression {
    fn from(ws: WildString) -> Self {
        Expression::Identifier(ws, IdentifierType::Symbol)
    }
}

// An intrinsic argument converts to a tracked variable reference by name.
impl From<&Argument> for Expression {
    fn from(a: &Argument) -> Self {
        Expression::Identifier(a.name.to_owned(), IdentifierType::Variable)
    }
}
// Converts a static definition into an expression: constants convert
// infallibly, generic parameters must parse as identifier strings.
impl TryFrom<&StaticDefinition> for Expression {
    type Error = String;

    fn try_from(sd: &StaticDefinition) -> Result<Self, Self::Error> {
        match sd {
            StaticDefinition::Constant(imm) => Ok(imm.into()),
            StaticDefinition::Generic(t) => t.parse(),
        }
    }
}
impl fmt::Display for Expression {
    /// Only identifiers (with their `$` sigil for variables) and macro calls
    /// have a canonical string form; every other variant yields `fmt::Error`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Identifier(identifier, kind) => {
                let sigil = if matches!(kind, IdentifierType::Variable) {
                    "$"
                } else {
                    ""
                };
                write!(f, "{sigil}{identifier}")
            }
            Self::MacroCall(name, expression) => write!(f, "{name}!({expression})"),
            _ => Err(fmt::Error),
        }
    }
}
impl ToTokens for Expression {
    /// Renders the expression as Rust source tokens.
    ///
    /// Panics on variants that must have been resolved away by the build
    /// passes (`MatchSize`/`MatchKind` and friends) — see the catch-all arm.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        match self {
            Self::Let(LetVariant::Basic(var_name, exp)) => {
                let var_ident = format_ident!("{}", var_name.to_string());
                tokens.append_all(quote! { let #var_ident = #exp })
            }
            Self::Let(LetVariant::WithType(var_name, ty, exp)) => {
                let var_ident = format_ident!("{}", var_name.to_string());
                tokens.append_all(quote! { let #var_ident: #ty = #exp })
            }
            Self::Assign(var_name, exp) => {
                let var_ident = format_ident!("{}", var_name);
                tokens.append_all(quote! { #var_ident = #exp })
            }
            // The macro arguments were validated as parseable in `from_str`,
            // hence the unwrap here.
            Self::MacroCall(name, ex) => {
                let name = format_ident!("{name}");
                let ex: TokenStream = ex.parse().unwrap();
                tokens.append_all(quote! { #name!(#ex) })
            }
            Self::FnCall(fn_call) => fn_call.to_tokens(tokens),
            Self::MethodCall(exp, fn_name, args) => {
                let fn_ident = format_ident!("{}", fn_name);
                tokens.append_all(quote! { #exp.#fn_ident(#(#args),*) })
            }
            // Identifiers must be fully wildcard-substituted before rendering.
            Self::Identifier(identifier, _) => {
                assert!(
                    !identifier.has_wildcards(),
                    "expression {self:#?} was not built before calling to_tokens"
                );
                identifier
                    .to_string()
                    .parse::<TokenStream>()
                    .expect("invalid syntax")
                    .to_tokens(tokens);
            }
            Self::IntConstant(n) => tokens.append(Literal::i32_unsuffixed(*n)),
            Self::FloatConstant(n) => tokens.append(Literal::f32_unsuffixed(*n)),
            Self::BoolConstant(true) => tokens.append(format_ident!("true")),
            Self::BoolConstant(false) => tokens.append(format_ident!("false")),
            Self::Array(vec) => tokens.append_all(quote! { [ #(#vec),* ] }),
            Self::LLVMLink(link) => link.to_tokens(tokens),
            Self::CastAs(ex, ty) => {
                let ty: TokenStream = ty.parse().expect("invalid syntax");
                tokens.append_all(quote! { #ex as #ty })
            }
            // `undef` is materialised via an empty reinterpret.
            Self::SvUndef => tokens.append_all(quote! { simd_reinterpret(()) }),
            Self::Multiply(lhs, rhs) => tokens.append_all(quote! { #lhs * #rhs }),
            Self::Type(ty) => ty.to_tokens(tokens),
            _ => unreachable!("{self:?} cannot be converted to tokens."),
        }
    }
}
impl Serialize for Expression {
    /// Serializes scalar-like variants as plain scalars/strings; everything
    /// else is delegated to the derived `remote = "Self"` serializer.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        match self {
            Self::IntConstant(v) => serializer.serialize_i32(*v),
            Self::FloatConstant(v) => serializer.serialize_f32(*v),
            Self::BoolConstant(v) => serializer.serialize_bool(*v),
            // Uses the `Display` shorthand (`$var`, `name!(args)`).
            Self::Identifier(..) => serializer.serialize_str(&self.to_string()),
            Self::MacroCall(..) => serializer.serialize_str(&self.to_string()),
            _ => Expression::serialize(self, serializer),
        }
    }
}
impl<'de> Deserialize<'de> for Expression {
    /// Accepts scalars (bool/int/float), strings (parsed via `FromStr`),
    /// sequences (arrays of expressions) and maps (the derived
    /// `remote = "Self"` form).
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct CustomExpressionVisitor;

        impl<'de> Visitor<'de> for CustomExpressionVisitor {
            type Value = Expression;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("integer, float, boolean, string or map")
            }

            fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                Ok(Expression::BoolConstant(v))
            }

            // Integers are narrowed to i32 (the generator's constant width).
            fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                Ok(Expression::IntConstant(v as i32))
            }

            fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                Ok(Expression::IntConstant(v as i32))
            }

            fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                Ok(Expression::FloatConstant(v as f32))
            }

            // Strings go through `Expression::from_str` (macro call,
            // `SvUndef` or identifier shorthand).
            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                FromStr::from_str(value).map_err(de::Error::custom)
            }

            // Sequences become `Expression::Array`, elements recursively
            // deserialized with this same visitor.
            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: de::SeqAccess<'de>,
            {
                let arr = std::iter::from_fn(|| seq.next_element::<Self::Value>().transpose())
                    .try_collect()?;
                Ok(Expression::Array(arr))
            }

            fn visit_map<M>(self, map: M) -> Result<Expression, M::Error>
            where
                M: MapAccess<'de>,
            {
                // `MapAccessDeserializer` is a wrapper that turns a `MapAccess`
                // into a `Deserializer`, allowing it to be used as the input to T's
                // `Deserialize` implementation. T then deserializes itself using
                // the entries from the map visitor.
                Expression::deserialize(de::value::MapAccessDeserializer::new(map))
            }
        }

        deserializer.deserialize_any(CustomExpressionVisitor)
    }
}

View file

@ -0,0 +1,432 @@
use itertools::Itertools;
use serde::{de, Deserialize, Deserializer, Serialize};
use crate::{
context::{self, GlobalContext},
intrinsic::Intrinsic,
predicate_forms::{PredicateForm, PredicationMask, PredicationMethods},
typekinds::TypeKind,
wildstring::WildString,
};
/// One argument of an [`InputSet`]: either a concrete type, a predication
/// form, or an optional N-variant operand.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum InputType {
    /// PredicateForm variant argument
    #[serde(skip)] // Predicate forms have their own dedicated deserialization field. Skip.
    PredicateForm(PredicateForm),
    /// Operand from which to generate an N variant
    #[serde(skip)]
    NVariantOp(Option<WildString>),
    /// TypeKind variant argument
    Type(TypeKind),
}
impl InputType {
    /// Returns the contained [`PredicateForm`], if this argument is one.
    pub fn predicate_form(&self) -> Option<&PredicateForm> {
        if let InputType::PredicateForm(form) = self {
            Some(form)
        } else {
            None
        }
    }

    /// Returns a mutable reference to the contained [`PredicateForm`], if any.
    pub fn predicate_form_mut(&mut self) -> Option<&mut PredicateForm> {
        if let InputType::PredicateForm(form) = self {
            Some(form)
        } else {
            None
        }
    }

    /// Returns the contained [`TypeKind`], if this argument is a type.
    pub fn typekind(&self) -> Option<&TypeKind> {
        if let InputType::Type(kind) = self {
            Some(kind)
        } else {
            None
        }
    }

    /// Returns the N-variant operand, if one is present.
    pub fn n_variant_op(&self) -> Option<&WildString> {
        if let InputType::NVariantOp(Some(operand)) = self {
            Some(operand)
        } else {
            None
        }
    }
}
impl PartialOrd for InputType {
    /// Delegates to the total order defined by [`Ord`].
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl Ord for InputType {
    /// Total order over input arguments: all `Type`s sort before all
    /// `PredicateForm`s, which sort before all `NVariantOp`s. Within a
    /// variant, `Type` and `PredicateForm` defer to their payload's order,
    /// while `NVariantOp` orders `None` first and treats any two `Some`s as
    /// equal (the operand itself does not participate in the order).
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        use std::cmp::Ordering;

        // Cross-variant rank: Type < PredicateForm < NVariantOp.
        fn rank(arg: &InputType) -> u8 {
            match arg {
                InputType::Type(..) => 0,
                InputType::PredicateForm(..) => 1,
                InputType::NVariantOp(..) => 2,
            }
        }

        match (self, other) {
            (InputType::Type(a), InputType::Type(b)) => a.cmp(b),
            (InputType::PredicateForm(a), InputType::PredicateForm(b)) => a.cmp(b),
            (InputType::NVariantOp(a), InputType::NVariantOp(b)) => match (a, b) {
                (None, Some(..)) => Ordering::Less,
                (Some(..), None) => Ordering::Greater,
                _ => Ordering::Equal,
            },
            _ => rank(self).cmp(&rank(other)),
        }
    }
}
mod many_or_one {
use serde::{de::Deserializer, ser::Serializer, Deserialize, Serialize};
pub fn serialize<T, S>(vec: &Vec<T>, serializer: S) -> Result<S::Ok, S::Error>
where
T: Serialize,
S: Serializer,
{
if vec.len() == 1 {
vec.first().unwrap().serialize(serializer)
} else {
vec.serialize(serializer)
}
}
pub fn deserialize<'de, T, D>(deserializer: D) -> Result<Vec<T>, D::Error>
where
T: Deserialize<'de>,
D: Deserializer<'de>,
{
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
enum ManyOrOne<T> {
Many(Vec<T>),
One(T),
}
match ManyOrOne::deserialize(deserializer)? {
ManyOrOne::Many(vec) => Ok(vec),
ManyOrOne::One(val) => Ok(vec![val]),
}
}
}
/// An ordered collection of input arguments (types, predicate form, N-variant
/// operand) describing one concrete variant of an intrinsic. Serialized as a
/// single element or a list (see `many_or_one`).
#[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct InputSet(#[serde(with = "many_or_one")] Vec<InputType>);
impl InputSet {
    /// Returns the argument at position `idx` in the full set, if present.
    pub fn get(&self, idx: usize) -> Option<&InputType> {
        self.0.get(idx)
    }

    /// Returns `true` if the set holds no arguments.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Iterates over the arguments by shared reference.
    pub fn iter(&self) -> impl Iterator<Item = &InputType> + '_ {
        self.0.iter()
    }

    /// Iterates over the arguments by mutable reference.
    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut InputType> + '_ {
        self.0.iter_mut()
    }

    /// Consumes the set, yielding its arguments in order.
    pub fn into_iter(self) -> impl Iterator<Item = InputType> + Clone {
        self.0.into_iter()
    }

    /// Number of `InputType::Type` arguments in the set.
    pub fn types_len(&self) -> usize {
        self.iter().filter_map(|arg| arg.typekind()).count()
    }

    /// Returns a clone of the type at position `idx` (default 0) of the full
    /// argument list, but only when the addressing mode matches the set's
    /// arity: `idx == None` is valid only for a single-type set, and an
    /// explicit `idx` only for a multi-type set. Otherwise yields `None`.
    pub fn typekind(&self, idx: Option<usize>) -> Option<TypeKind> {
        let types_len = self.types_len();
        self.get(idx.unwrap_or(0)).and_then(move |arg: &InputType| {
            if (idx.is_none() && types_len != 1) || (idx.is_some() && types_len == 1) {
                None
            } else {
                arg.typekind().cloned()
            }
        })
    }
}
/// One entry of the `types` list: one or more [`InputSet`]s whose cartesian
/// product contributes intrinsic variants.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct InputSetEntry(#[serde(with = "many_or_one")] Vec<InputSet>);

impl InputSetEntry {
    /// Creates an entry from the given list of input sets.
    pub fn new(input: Vec<InputSet>) -> Self {
        Self(input)
    }

    /// Returns the input set at position `idx`, if present.
    pub fn get(&self, idx: usize) -> Option<&InputSet> {
        self.0.get(idx)
    }
}
fn validate_types<'de, D>(deserializer: D) -> Result<Vec<InputSetEntry>, D::Error>
where
D: Deserializer<'de>,
{
let v: Vec<InputSetEntry> = Vec::deserialize(deserializer)?;
let mut it = v.iter();
if let Some(first) = it.next() {
it.try_fold(first, |last, cur| {
if last.0.len() == cur.0.len() {
Ok(cur)
} else {
Err("the length of the InputSets and the product lists must match".to_string())
}
})
.map_err(de::Error::custom)?;
}
Ok(v)
}
/// Per-intrinsic input specification: the type sets to expand, how to
/// predicate the variants, and an optional N-variant operand.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct IntrinsicInput {
    /// Type sets to expand into variants; every entry must have the same
    /// number of input sets (enforced by `validate_types`).
    #[serde(default)]
    #[serde(deserialize_with = "validate_types")]
    pub types: Vec<InputSetEntry>,

    /// Predication configuration, flattened from the surrounding YAML map.
    #[serde(flatten)]
    pub predication_methods: PredicationMethods,

    /// Generates a _n variant where the specified operand is a primitive type
    /// that requires conversion to an SVE one. The `{_n}` wildcard is required
    /// in the intrinsic's name, otherwise an error will be thrown.
    #[serde(default)]
    pub n_variant_op: WildString,
}
impl IntrinsicInput {
    /// Extracts all the possible variants as an iterator.
    ///
    /// Builds the cartesian product of three independent dimensions — the
    /// expanded type sets, the predicate forms implied by the predication
    /// mask in the intrinsic's name, and the presence/absence of the
    /// N-variant operand — and flattens each combination into one
    /// [`InputSet`].
    pub fn variants(
        &self,
        intrinsic: &Intrinsic,
    ) -> context::Result<impl Iterator<Item = InputSet> + '_> {
        let mut top_product = vec![];

        // Dimension 1: every combination of types across the input sets of
        // each `types` entry.
        if !self.types.is_empty() {
            top_product.push(
                self.types
                    .iter()
                    .flat_map(|ty_in| {
                        ty_in
                            .0
                            .iter()
                            .map(|v| v.clone().into_iter())
                            .multi_cartesian_product()
                    })
                    .collect_vec(),
            )
        }

        // Dimension 2: predicate forms, derived from the `{_m…}` mask in the
        // intrinsic's name (absent mask means no predication dimension).
        if let Ok(mask) = PredicationMask::try_from(&intrinsic.signature.name) {
            top_product.push(
                PredicateForm::compile_list(&mask, &self.predication_methods)?
                    .into_iter()
                    .map(|pf| vec![InputType::PredicateForm(pf)])
                    .collect_vec(),
            )
        }

        // Dimension 3: without and with the `_n` operand variant.
        if !self.n_variant_op.is_empty() {
            top_product.push(vec![
                vec![InputType::NVariantOp(None)],
                vec![InputType::NVariantOp(Some(self.n_variant_op.to_owned()))],
            ])
        }

        // Product of all present dimensions; each combination flattens into a
        // single InputSet. An empty product yields no variants at all.
        let it = top_product
            .into_iter()
            .map(|v| v.into_iter())
            .multi_cartesian_product()
            .map(|set| InputSet(set.into_iter().flatten().collect_vec()));
        Ok(it)
    }
}
/// Top-level layout of a generator spec file: the global context (flattened)
/// plus the list of intrinsics to generate.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeneratorInput {
    #[serde(flatten)]
    pub ctx: GlobalContext,
    pub intrinsics: Vec<Intrinsic>,
}
#[cfg(test)]
mod tests {
    use crate::{
        input::*,
        predicate_forms::{DontCareMethod, ZeroingMethod},
    };

    /// An empty `types` list with no predication yields no variants.
    #[test]
    fn test_empty() {
        let str = r#"types: []"#;
        let input: IntrinsicInput = serde_yaml::from_str(str).expect("failed to parse");
        let mut variants = input.variants(&Intrinsic::default()).unwrap().into_iter();
        assert_eq!(variants.next(), None);
    }

    /// Full cartesian product of type entries and the `{_mx}` predicate
    /// forms, in the expected iteration order.
    #[test]
    fn test_product() {
        let str = r#"types:
- [f64, f32]
- [i64, [f64, f32]]
"#;
        let input: IntrinsicInput = serde_yaml::from_str(str).expect("failed to parse");
        let mut intrinsic = Intrinsic::default();
        intrinsic.signature.name = "test_intrinsic{_mx}".parse().unwrap();
        let mut variants = input.variants(&intrinsic).unwrap().into_iter();
        assert_eq!(
            variants.next(),
            Some(InputSet(vec![
                InputType::Type("f64".parse().unwrap()),
                InputType::Type("f32".parse().unwrap()),
                InputType::PredicateForm(PredicateForm::Merging),
            ]))
        );
        assert_eq!(
            variants.next(),
            Some(InputSet(vec![
                InputType::Type("f64".parse().unwrap()),
                InputType::Type("f32".parse().unwrap()),
                InputType::PredicateForm(PredicateForm::DontCare(DontCareMethod::AsMerging)),
            ]))
        );
        assert_eq!(
            variants.next(),
            Some(InputSet(vec![
                InputType::Type("i64".parse().unwrap()),
                InputType::Type("f64".parse().unwrap()),
                InputType::PredicateForm(PredicateForm::Merging),
            ]))
        );
        assert_eq!(
            variants.next(),
            Some(InputSet(vec![
                InputType::Type("i64".parse().unwrap()),
                InputType::Type("f64".parse().unwrap()),
                InputType::PredicateForm(PredicateForm::DontCare(DontCareMethod::AsMerging)),
            ]))
        );
        assert_eq!(
            variants.next(),
            Some(InputSet(vec![
                InputType::Type("i64".parse().unwrap()),
                InputType::Type("f32".parse().unwrap()),
                InputType::PredicateForm(PredicateForm::Merging),
            ]))
        );
        assert_eq!(
            variants.next(),
            Some(InputSet(vec![
                InputType::Type("i64".parse().unwrap()),
                InputType::Type("f32".parse().unwrap()),
                InputType::PredicateForm(PredicateForm::DontCare(DontCareMethod::AsMerging)),
            ])),
        );
        assert_eq!(variants.next(), None);
    }

    /// `n_variant_op` doubles the variant count: one set without and one with
    /// the N-variant operand.
    #[test]
    fn test_n_variant() {
        let str = r#"types:
- [f64, f32]
n_variant_op: op2
"#;
        let input: IntrinsicInput = serde_yaml::from_str(str).expect("failed to parse");
        let mut variants = input.variants(&Intrinsic::default()).unwrap().into_iter();
        assert_eq!(
            variants.next(),
            Some(InputSet(vec![
                InputType::Type("f64".parse().unwrap()),
                InputType::Type("f32".parse().unwrap()),
                InputType::NVariantOp(None),
            ]))
        );
        assert_eq!(
            variants.next(),
            Some(InputSet(vec![
                InputType::Type("f64".parse().unwrap()),
                InputType::Type("f32".parse().unwrap()),
                InputType::NVariantOp(Some("op2".parse().unwrap())),
            ]))
        );
        assert_eq!(variants.next(), None)
    }

    /// Entries with mismatched InputSet counts must fail `validate_types`.
    #[test]
    fn test_invalid_length() {
        let str = r#"types: [i32, [[u64], [u32]]]"#;
        serde_yaml::from_str::<IntrinsicInput>(str).expect_err("failure expected");
    }

    /// A `{_mxz}` mask without a configured zeroing method is an error.
    #[test]
    fn test_invalid_predication() {
        let str = "types: []";
        let input: IntrinsicInput = serde_yaml::from_str(str).expect("failed to parse");
        let mut intrinsic = Intrinsic::default();
        intrinsic.signature.name = "test_intrinsic{_mxz}".parse().unwrap();
        input
            .variants(&intrinsic)
            .map(|v| v.collect_vec())
            .expect_err("failure expected");
    }

    /// Malformed predication masks are rejected at name-parse time.
    #[test]
    fn test_invalid_predication_mask() {
        "test_intrinsic{_mxy}"
            .parse::<WildString>()
            .expect_err("failure expected");
        "test_intrinsic{_}"
            .parse::<WildString>()
            .expect_err("failure expected");
    }

    /// A `{_mxz}` mask with a zeroing method yields merging, don't-care-as-
    /// zeroing, and zeroing forms in that order.
    #[test]
    fn test_zeroing_predication() {
        let str = r#"types: [i64]
zeroing_method: { drop: inactive }"#;
        let input: IntrinsicInput = serde_yaml::from_str(str).expect("failed to parse");
        let mut intrinsic = Intrinsic::default();
        intrinsic.signature.name = "test_intrinsic{_mxz}".parse().unwrap();
        let mut variants = input.variants(&intrinsic).unwrap();
        assert_eq!(
            variants.next(),
            Some(InputSet(vec![
                InputType::Type("i64".parse().unwrap()),
                InputType::PredicateForm(PredicateForm::Merging),
            ]))
        );
        assert_eq!(
            variants.next(),
            Some(InputSet(vec![
                InputType::Type("i64".parse().unwrap()),
                InputType::PredicateForm(PredicateForm::DontCare(DontCareMethod::AsZeroing)),
            ]))
        );
        assert_eq!(
            variants.next(),
            Some(InputSet(vec![
                InputType::Type("i64".parse().unwrap()),
                InputType::PredicateForm(PredicateForm::Zeroing(ZeroingMethod::Drop {
                    drop: "inactive".parse().unwrap()
                })),
            ]))
        );
        assert_eq!(variants.next(), None)
    }
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,818 @@
use std::fs::File;
use std::io::Write;
use std::path::PathBuf;
use std::str::FromStr;
use crate::format_code;
use crate::input::InputType;
use crate::intrinsic::Intrinsic;
use crate::typekinds::BaseType;
use crate::typekinds::{ToRepr, TypeKind};
use itertools::Itertools;
use lazy_static::lazy_static;
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
// Number of vectors in our buffers - the maximum tuple size, 4, plus 1 as we set the vnum
// argument to 1.
const NUM_VECS: usize = 5;
// The maximum vector length (in bits)
const VL_MAX_BITS: usize = 2048;
// The maximum vector length (in bytes)
const VL_MAX_BYTES: usize = VL_MAX_BITS / 8;
// The maximum number of elements in each vector type, assuming the maximum
// vector length. These size the static data arrays in the preamble.
const LEN_F32: usize = VL_MAX_BYTES / core::mem::size_of::<f32>();
const LEN_F64: usize = VL_MAX_BYTES / core::mem::size_of::<f64>();
const LEN_I8: usize = VL_MAX_BYTES / core::mem::size_of::<i8>();
const LEN_I16: usize = VL_MAX_BYTES / core::mem::size_of::<i16>();
const LEN_I32: usize = VL_MAX_BYTES / core::mem::size_of::<i32>();
const LEN_I64: usize = VL_MAX_BYTES / core::mem::size_of::<i64>();
const LEN_U8: usize = VL_MAX_BYTES / core::mem::size_of::<u8>();
const LEN_U16: usize = VL_MAX_BYTES / core::mem::size_of::<u16>();
const LEN_U32: usize = VL_MAX_BYTES / core::mem::size_of::<u32>();
const LEN_U64: usize = VL_MAX_BYTES / core::mem::size_of::<u64>();
/// `load_intrinsics` and `store_intrinsics` is a vector of intrinsics
/// variants, while `out_path` is a file to write to.
///
/// Pairs each load intrinsic with its matching store intrinsic (found by name
/// substitution), generates one round-trip test per load, and writes the
/// formatted test file to `out_path` (or stdout when `None`).
pub fn generate_load_store_tests(
    load_intrinsics: Vec<Intrinsic>,
    store_intrinsics: Vec<Intrinsic>,
    out_path: Option<&PathBuf>,
) -> Result<(), String> {
    let output = match out_path {
        Some(out) => {
            Box::new(File::create(out).map_err(|e| format!("couldn't create tests file: {e}"))?)
                as Box<dyn Write>
        }
        None => Box::new(std::io::stdout()) as Box<dyn Write>,
    };
    // Track which store intrinsics got paired, so unpaired ones can be
    // reported by the assert below.
    let mut used_stores = vec![false; store_intrinsics.len()];
    let tests: Vec<_> = load_intrinsics
        .iter()
        .map(|load| {
            // Derive the matching store intrinsic's name from the load's
            // name (e.g. svld1 -> svst1, gather -> scatter).
            let store_candidate = load
                .signature
                .fn_name()
                .to_string()
                .replace("svld1s", "svst1")
                .replace("svld1u", "svst1")
                .replace("svldnt1s", "svstnt1")
                .replace("svldnt1u", "svstnt1")
                .replace("svld", "svst")
                .replace("gather", "scatter");
            let store_index = store_intrinsics
                .iter()
                .position(|i| i.signature.fn_name().to_string() == store_candidate);
            if let Some(i) = store_index {
                used_stores[i] = true;
            }
            generate_single_test(
                load.clone(),
                store_index.map(|i| store_intrinsics[i].clone()),
            )
        })
        .try_collect()?;
    assert!(used_stores.into_iter().all(|b| b), "Not all store tests have been paired with a load. Consider generating specifc store-only tests");
    let preamble =
        TokenStream::from_str(&PREAMBLE).map_err(|e| format!("Preamble is invalid: {e}"))?;
    // Only output manual tests for the SVE set
    // NOTE(review): `load_intrinsics[0]` panics when the load list is empty —
    // presumably callers always supply at least one load intrinsic; confirm.
    let manual_tests = match &load_intrinsics[0].target_features[..] {
        [s] if s == "sve" => TokenStream::from_str(&MANUAL_TESTS)
            .map_err(|e| format!("Manual tests are invalid: {e}"))?,
        _ => quote!(),
    };
    format_code(
        output,
        format!(
            "// This code is automatically generated. DO NOT MODIFY.
//
// Instead, modify `crates/stdarch-gen2/spec/sve` and run the following command to re-generate this
// file:
//
// ```
// cargo run --bin=stdarch-gen2 -- crates/stdarch-gen2/spec
// ```
{}",
            quote! { #preamble #(#tests)* #manual_tests }
        ),
    )
    .map_err(|e| format!("couldn't write tests: {e}"))
}
/// A test looks like this:
/// ```
/// let data = [scalable vector];
///
/// let mut storage = [0; N];
///
/// store_intrinsic([true_predicate], storage.as_mut_ptr(), data);
/// [test contents of storage]
///
/// let loaded == load_intrinsic([true_predicate], storage.as_ptr())
/// assert!(loaded == data);
/// ```
/// We initialise our data such that the value stored matches the index it's stored to.
/// By doing this we can validate scatters by checking that each value in the storage
/// array is either 0 or the same as its index.
///
/// Returns an empty token stream for intrinsics that are deliberately skipped
/// (32-bit gather bases without offset/index, first-faulting gathers).
fn generate_single_test(
    load: Intrinsic,
    store: Option<Intrinsic>,
) -> Result<proc_macro2::TokenStream, String> {
    let chars = LdIntrCharacteristics::new(&load)?;
    let fn_name = load.signature.fn_name().to_string();

    if let Some(ty) = &chars.gather_bases_type {
        if ty.base_type().unwrap().get_size() == Ok(32)
            && chars.gather_index_type.is_none()
            && chars.gather_offset_type.is_none()
        {
            // We lack a way to ensure data is in the bottom 32 bits of the address space
            println!("Skipping test for {fn_name}");
            return Ok(quote!());
        }
    }

    if fn_name.starts_with("svldff1") && fn_name.contains("gather") {
        // TODO: We can remove this check when first-faulting gathers are fixed in CI's QEMU
        // https://gitlab.com/qemu-project/qemu/-/issues/1612
        println!("Skipping test for {fn_name}");
        return Ok(quote!());
    }

    let fn_ident = format_ident!("{fn_name}");
    let test_name = format_ident!(
        "test_{fn_name}{}",
        if let Some(ref store) = store {
            format!("_with_{}", store.signature.fn_name())
        } else {
            String::new()
        }
    );
    let load_type = &chars.load_type;
    let acle_type = load_type.acle_notation_repr();

    // If there's no return type, fallback to the load type for things that depend on it
    let ret_type = &load
        .signature
        .return_type
        .as_ref()
        .and_then(TypeKind::base_type)
        .unwrap_or(load_type);

    let pred_fn = format_ident!("svptrue_b{}", load_type.size());

    let load_type_caps = load_type.rust_repr().to_uppercase();
    let data_array = format_ident!("{load_type_caps}_DATA");

    let size_fn = format_ident!("svcnt{}", ret_type.size_literal());

    let rust_ret_type = ret_type.rust_repr();
    let assert_fn = format_ident!("assert_vector_matches_{rust_ret_type}");

    // Use vnum=1, so adjust all values by one vector length
    let (length_call, vnum_arg) = if chars.vnum {
        if chars.is_prf {
            (quote!(), quote!(, 1))
        } else {
            (quote!(let len = #size_fn() as usize;), quote!(, 1))
        }
    } else {
        (quote!(), quote!())
    };

    let (bases_load, bases_arg) = if let Some(ty) = &chars.gather_bases_type {
        // Bases is a vector of (sometimes 32-bit) pointers
        // When we combine bases with an offset/index argument, we load from the data arrays
        // starting at 1
        let base_ty = ty.base_type().unwrap();
        let rust_type = format_ident!("{}", base_ty.rust_repr());
        let index_fn = format_ident!("svindex_{}", base_ty.acle_notation_repr());
        let size_in_bytes = chars.load_type.get_size().unwrap() / 8;

        if base_ty.get_size().unwrap() == 32 {
            // Treat bases as a vector of offsets here - we don't test this without an offset or
            // index argument
            (
                Some(quote!(
                    let bases = #index_fn(0, #size_in_bytes.try_into().unwrap());
                )),
                quote!(, bases),
            )
        } else {
            // Treat bases as a vector of pointers
            let base_fn = format_ident!("svdup_n_{}", base_ty.acle_notation_repr());
            let data_array = if store.is_some() {
                format_ident!("storage")
            } else {
                format_ident!("{}_DATA", chars.load_type.rust_repr().to_uppercase())
            };

            let add_fn = format_ident!("svadd_{}_x", base_ty.acle_notation_repr());
            (
                Some(quote! {
                    let bases = #base_fn(#data_array.as_ptr() as #rust_type);
                    let offsets = #index_fn(0, #size_in_bytes.try_into().unwrap());
                    let bases = #add_fn(#pred_fn(), bases, offsets);
                }),
                quote!(, bases),
            )
        }
    } else {
        (None, quote!())
    };

    let index_arg = if let Some(ty) = &chars.gather_index_type {
        let rust_type = format_ident!("{}", ty.rust_repr());
        if chars
            .gather_bases_type
            .as_ref()
            .and_then(TypeKind::base_type)
            .map_or(Err(String::new()), BaseType::get_size)
            .unwrap()
            == 32
        {
            // Let index be the base of the data array
            let data_array = if store.is_some() {
                format_ident!("storage")
            } else {
                format_ident!("{}_DATA", chars.load_type.rust_repr().to_uppercase())
            };
            let size_in_bytes = chars.load_type.get_size().unwrap() / 8;
            quote!(, #data_array.as_ptr() as #rust_type / (#size_in_bytes as #rust_type) + 1)
        } else {
            quote!(, 1.try_into().unwrap())
        }
    } else {
        quote!()
    };

    let offset_arg = if let Some(ty) = &chars.gather_offset_type {
        let size_in_bytes = chars.load_type.get_size().unwrap() / 8;
        if chars
            .gather_bases_type
            .as_ref()
            .and_then(TypeKind::base_type)
            .map_or(Err(String::new()), BaseType::get_size)
            .unwrap()
            == 32
        {
            // Let offset be the base of the data array
            let rust_type = format_ident!("{}", ty.rust_repr());
            let data_array = if store.is_some() {
                format_ident!("storage")
            } else {
                format_ident!("{}_DATA", chars.load_type.rust_repr().to_uppercase())
            };
            quote!(, #data_array.as_ptr() as #rust_type + #size_in_bytes as #rust_type)
        } else {
            quote!(, #size_in_bytes.try_into().unwrap())
        }
    } else {
        quote!()
    };

    let (offsets_load, offsets_arg) = if let Some(ty) = &chars.gather_offsets_type {
        // Offsets is a scalable vector of per-element offsets in bytes. We re-use the contiguous
        // data for this, then multiply to get indices
        let offsets_fn = format_ident!("svindex_{}", ty.base_type().unwrap().acle_notation_repr());
        let size_in_bytes = chars.load_type.get_size().unwrap() / 8;
        (
            Some(quote! {
                let offsets = #offsets_fn(0, #size_in_bytes.try_into().unwrap());
            }),
            quote!(, offsets),
        )
    } else {
        (None, quote!())
    };

    let (indices_load, indices_arg) = if let Some(ty) = &chars.gather_indices_type {
        // There's no need to multiply indices by the load type width
        let base_ty = ty.base_type().unwrap();
        let indices_fn = format_ident!("svindex_{}", base_ty.acle_notation_repr());
        (
            Some(quote! {
                let indices = #indices_fn(0, 1);
            }),
            quote! {, indices},
        )
    } else {
        (None, quote!())
    };

    // Pointer argument for contiguous loads (gathers carry addresses in `bases`).
    let ptr = if chars.gather_bases_type.is_some() {
        quote!()
    } else if chars.is_prf {
        quote!(, I64_DATA.as_ptr())
    } else {
        quote!(, #data_array.as_ptr())
    };

    let tuple_len = &chars.tuple_len;
    let expecteds = if chars.is_prf {
        // No return value for prefetches
        vec![]
    } else {
        (0..*tuple_len)
            .map(|i| get_expected_range(i, &chars))
            .collect()
    };
    // For tuple loads, each element of the tuple is asserted individually via svget.
    let asserts: Vec<_> =
        if *tuple_len > 1 {
            let svget = format_ident!("svget{tuple_len}_{acle_type}");
            expecteds.iter().enumerate().map(|(i, expected)| {
                quote! (#assert_fn(#svget::<{ #i as i32 }>(loaded), #expected);)
            }).collect()
        } else {
            expecteds
                .iter()
                .map(|expected| quote! (#assert_fn(loaded, #expected);))
                .collect()
        };

    let function = if chars.is_prf {
        if fn_name.contains("gather") && fn_name.contains("base") && !fn_name.starts_with("svprf_")
        {
            // svprf(b|h|w|d)_gather base intrinsics do not have a generic type parameter
            quote!(#fn_ident::<{ svprfop::SV_PLDL1KEEP }>)
        } else {
            quote!(#fn_ident::<{ svprfop::SV_PLDL1KEEP }, i64>)
        }
    } else {
        quote!(#fn_ident)
    };

    // svld1ro needs at least a 256-bit vector; emit a runtime guard.
    let octaword_guard = if chars.replicate_width == Some(256) {
        let msg = format!("Skipping {test_name} due to SVE vector length");
        quote! {
            if svcntb() < 32 {
                println!(#msg);
                return;
            }
        }
    } else {
        quote!()
    };

    let feats = load.target_features.join(",");

    if let Some(store) = store {
        // Round-trip test: store generated data, validate storage, load it back.
        let data_init = if *tuple_len == 1 {
            quote!(#(#expecteds)*)
        } else {
            let create = format_ident!("svcreate{tuple_len}_{acle_type}");
            quote!(#create(#(#expecteds),*))
        };
        let input = store.input.types.get(0).unwrap().get(0).unwrap();
        let store_type = input
            .get(store.test.get_typeset_index().unwrap())
            .and_then(InputType::typekind)
            .and_then(TypeKind::base_type)
            .unwrap();

        let store_type = format_ident!("{}", store_type.rust_repr());
        let storage_len = NUM_VECS * VL_MAX_BITS / chars.load_type.get_size()? as usize;
        let store_fn = format_ident!("{}", store.signature.fn_name().to_string());
        let load_type = format_ident!("{}", chars.load_type.rust_repr());

        let (store_ptr, store_mut_ptr) = if chars.gather_bases_type.is_none() {
            (
                quote!(, storage.as_ptr() as *const #load_type),
                quote!(, storage.as_mut_ptr()),
            )
        } else {
            (quote!(), quote!())
        };

        let args = quote!(#pred_fn() #store_ptr #vnum_arg #bases_arg #offset_arg #index_arg #offsets_arg #indices_arg);

        let call = if chars.uses_ffr {
            // Doing a normal load first maximises the number of elements our ff/nf test loads
            let non_ffr_fn_name = format_ident!(
                "{}",
                fn_name
                    .replace("svldff1", "svld1")
                    .replace("svldnf1", "svld1")
            );
            quote! {
                svsetffr();
                let _ = #non_ffr_fn_name(#args);
                let loaded = #function(#args);
            }
        } else {
            // Note that the FFR must be set for all tests as the assert functions mask against it
            quote! {
                svsetffr();
                let loaded = #function(#args);
            }
        };

        Ok(quote! {
            #[simd_test(enable = #feats)]
            unsafe fn #test_name() {
                #octaword_guard
                #length_call
                let mut storage = [0 as #store_type; #storage_len];
                let data = #data_init;
                #bases_load
                #offsets_load
                #indices_load

                #store_fn(#pred_fn() #store_mut_ptr #vnum_arg #bases_arg #offset_arg #index_arg #offsets_arg #indices_arg, data);
                for (i, &val) in storage.iter().enumerate() {
                    assert!(val == 0 as #store_type || val == i as #store_type);
                }

                #call
                #(#asserts)*
            }
        })
    } else {
        // Load-only test: load from the static data arrays and assert directly.
        let args = quote!(#pred_fn() #ptr #vnum_arg #bases_arg #offset_arg #index_arg #offsets_arg #indices_arg);
        let call = if chars.uses_ffr {
            // Doing a normal load first maximises the number of elements our ff/nf test loads
            let non_ffr_fn_name = format_ident!(
                "{}",
                fn_name
                    .replace("svldff1", "svld1")
                    .replace("svldnf1", "svld1")
            );
            quote! {
                svsetffr();
                let _ = #non_ffr_fn_name(#args);
                let loaded = #function(#args);
            }
        } else {
            // Note that the FFR must be set for all tests as the assert functions mask against it
            quote! {
                svsetffr();
                let loaded = #function(#args);
            }
        };

        Ok(quote! {
            #[simd_test(enable = #feats)]
            unsafe fn #test_name() {
                #octaword_guard
                #bases_load
                #offsets_load
                #indices_load
                #call
                #length_call
                #(#asserts)*
            }
        })
    }
}
/// Assumes chars.ret_type is not None
///
/// Builds the token stream for the expected vector value of tuple element
/// `tuple_idx`: a linear index sequence for ordinary loads, or a repeating
/// 128/256-bit pattern for svld1rq/svld1ro.
fn get_expected_range(tuple_idx: usize, chars: &LdIntrCharacteristics) -> proc_macro2::TokenStream {
    // vnum=1
    let vnum_adjust = if chars.vnum { quote!(len+) } else { quote!() };

    // With a scalar offset/index argument the data starts one element later.
    let bases_adjust =
        (chars.gather_index_type.is_some() || chars.gather_offset_type.is_some()) as usize;

    let tuple_len = chars.tuple_len;
    let size = chars
        .ret_type
        .as_ref()
        .and_then(TypeKind::base_type)
        .unwrap_or(&chars.load_type)
        .get_size()
        .unwrap() as usize;

    if chars.replicate_width == Some(128) {
        // svld1rq
        let ty_rust = format_ident!(
            "{}",
            chars
                .ret_type
                .as_ref()
                .unwrap()
                .base_type()
                .unwrap()
                .rust_repr()
        );
        let args: Vec<_> = (0..(128 / size)).map(|i| quote!(#i as #ty_rust)).collect();
        let dup = format_ident!(
            "svdupq_n_{}",
            chars.ret_type.as_ref().unwrap().acle_notation_repr()
        );
        quote!(#dup(#(#args,)*))
    } else if chars.replicate_width == Some(256) {
        // svld1ro - we use two interleaved svdups to create a repeating 256-bit pattern
        let ty_rust = format_ident!(
            "{}",
            chars
                .ret_type
                .as_ref()
                .unwrap()
                .base_type()
                .unwrap()
                .rust_repr()
        );
        let ret_acle = chars.ret_type.as_ref().unwrap().acle_notation_repr();
        let args: Vec<_> = (0..(128 / size)).map(|i| quote!(#i as #ty_rust)).collect();
        let args2: Vec<_> = ((128 / size)..(256 / size))
            .map(|i| quote!(#i as #ty_rust))
            .collect();
        let dup = format_ident!("svdupq_n_{ret_acle}");
        let interleave = format_ident!("svtrn1q_{ret_acle}");
        quote!(#interleave(#dup(#(#args,)*), #dup(#(#args2,)*)))
    } else {
        // Ordinary load: expected values are a linear sequence starting at
        // `start`, stepping by the tuple length.
        let start = bases_adjust + tuple_idx;
        if chars
            .ret_type
            .as_ref()
            .unwrap()
            .base_type()
            .unwrap()
            .is_float()
        {
            // Use svcvt to create a linear sequence of floats
            let cvt_fn = format_ident!("svcvt_f{size}_s{size}_x");
            let pred_fn = format_ident!("svptrue_b{size}");
            let svindex_fn = format_ident!("svindex_s{size}");
            quote! { #cvt_fn(#pred_fn(), #svindex_fn((#vnum_adjust #start).try_into().unwrap(), #tuple_len.try_into().unwrap()))}
        } else {
            let ret_acle = chars.ret_type.as_ref().unwrap().acle_notation_repr();
            let svindex = format_ident!("svindex_{ret_acle}");
            quote!(#svindex((#vnum_adjust #start).try_into().unwrap(), #tuple_len.try_into().unwrap()))
        }
    }
}
/// Facts about a load intrinsic that drive test generation, derived from its
/// name and signature.
struct LdIntrCharacteristics {
    // The data type to load from (not necessarily the data type returned)
    load_type: BaseType,
    // The data type to return (None for unit)
    ret_type: Option<TypeKind>,
    // The size of tuple to load/store
    tuple_len: usize,
    // Whether a vnum argument is present
    vnum: bool,
    // Is the intrinsic first/non-faulting?
    uses_ffr: bool,
    // Is it a prefetch?
    is_prf: bool,
    // The size of data loaded with svld1ro/q intrinsics
    replicate_width: Option<usize>,
    // Scalable vector of pointers to load from
    gather_bases_type: Option<TypeKind>,
    // Scalar offset, paired with bases
    gather_offset_type: Option<TypeKind>,
    // Scalar index, paired with bases
    gather_index_type: Option<TypeKind>,
    // Scalable vector of offsets
    gather_offsets_type: Option<TypeKind>,
    // Scalable vector of indices
    gather_indices_type: Option<TypeKind>,
}
impl LdIntrCharacteristics {
    /// Derives the characteristics of `intr` from its first type set, its
    /// function name, and its argument names.
    fn new(intr: &Intrinsic) -> Result<LdIntrCharacteristics, String> {
        let input = intr.input.types.get(0).unwrap().get(0).unwrap();
        let load_type = input
            .get(intr.test.get_typeset_index().unwrap())
            .and_then(InputType::typekind)
            .and_then(TypeKind::base_type)
            .unwrap();

        let ret_type = intr.signature.return_type.clone();

        let name = intr.signature.fn_name().to_string();
        // First digit in the name is the tuple size (e.g. svld2 -> 2);
        // names without a digit default to 1.
        let tuple_len = name
            .chars()
            .find(|c| c.is_numeric())
            .and_then(|c| c.to_digit(10))
            .unwrap_or(1) as usize;

        let uses_ffr = name.starts_with("svldff") || name.starts_with("svldnf");

        let is_prf = name.starts_with("svprf");

        let replicate_width = if name.starts_with("svld1ro") {
            Some(256)
        } else if name.starts_with("svld1rq") {
            Some(128)
        } else {
            None
        };

        // Gather arguments are identified purely by argument name.
        let get_ty_of_arg = |name: &str| {
            intr.signature
                .arguments
                .iter()
                .find(|a| a.name.to_string() == name)
                .map(|a| a.kind.clone())
        };

        let gather_bases_type = get_ty_of_arg("bases");
        let gather_offset_type = get_ty_of_arg("offset");
        let gather_index_type = get_ty_of_arg("index");
        let gather_offsets_type = get_ty_of_arg("offsets");
        let gather_indices_type = get_ty_of_arg("indices");

        Ok(LdIntrCharacteristics {
            load_type: *load_type,
            ret_type,
            tuple_len,
            vnum: name.contains("vnum"),
            uses_ffr,
            is_prf,
            replicate_width,
            gather_bases_type,
            gather_offset_type,
            gather_index_type,
            gather_offsets_type,
            gather_indices_type,
        })
    }
}
// Static preamble emitted at the top of the generated test file: data arrays
// initialised so each element equals its index, plus per-type vector assertion
// helpers that mask comparisons against the FFR.
lazy_static! {
    static ref PREAMBLE: String = format!(
        r#"#![allow(unused)]

use super::*;
use std::boxed::Box;
use std::convert::{{TryFrom, TryInto}};
use std::sync::LazyLock;
use std::vec::Vec;
use stdarch_test::simd_test;

static F32_DATA: LazyLock<[f32; {LEN_F32} * {NUM_VECS}]> = LazyLock::new(|| {{
    (0..{LEN_F32} * {NUM_VECS})
        .map(|i| i as f32)
        .collect::<Vec<_>>()
        .try_into()
        .expect("f32 data incorrectly initialised")
}});
static F64_DATA: LazyLock<[f64; {LEN_F64} * {NUM_VECS}]> = LazyLock::new(|| {{
    (0..{LEN_F64} * {NUM_VECS})
        .map(|i| i as f64)
        .collect::<Vec<_>>()
        .try_into()
        .expect("f64 data incorrectly initialised")
}});
static I8_DATA: LazyLock<[i8; {LEN_I8} * {NUM_VECS}]> = LazyLock::new(|| {{
    (0..{LEN_I8} * {NUM_VECS})
        .map(|i| ((i + 128) % 256 - 128) as i8)
        .collect::<Vec<_>>()
        .try_into()
        .expect("i8 data incorrectly initialised")
}});
static I16_DATA: LazyLock<[i16; {LEN_I16} * {NUM_VECS}]> = LazyLock::new(|| {{
    (0..{LEN_I16} * {NUM_VECS})
        .map(|i| i as i16)
        .collect::<Vec<_>>()
        .try_into()
        .expect("i16 data incorrectly initialised")
}});
static I32_DATA: LazyLock<[i32; {LEN_I32} * {NUM_VECS}]> = LazyLock::new(|| {{
    (0..{LEN_I32} * {NUM_VECS})
        .map(|i| i as i32)
        .collect::<Vec<_>>()
        .try_into()
        .expect("i32 data incorrectly initialised")
}});
static I64_DATA: LazyLock<[i64; {LEN_I64} * {NUM_VECS}]> = LazyLock::new(|| {{
    (0..{LEN_I64} * {NUM_VECS})
        .map(|i| i as i64)
        .collect::<Vec<_>>()
        .try_into()
        .expect("i64 data incorrectly initialised")
}});
static U8_DATA: LazyLock<[u8; {LEN_U8} * {NUM_VECS}]> = LazyLock::new(|| {{
    (0..{LEN_U8} * {NUM_VECS})
        .map(|i| i as u8)
        .collect::<Vec<_>>()
        .try_into()
        .expect("u8 data incorrectly initialised")
}});
static U16_DATA: LazyLock<[u16; {LEN_U16} * {NUM_VECS}]> = LazyLock::new(|| {{
    (0..{LEN_U16} * {NUM_VECS})
        .map(|i| i as u16)
        .collect::<Vec<_>>()
        .try_into()
        .expect("u16 data incorrectly initialised")
}});
static U32_DATA: LazyLock<[u32; {LEN_U32} * {NUM_VECS}]> = LazyLock::new(|| {{
    (0..{LEN_U32} * {NUM_VECS})
        .map(|i| i as u32)
        .collect::<Vec<_>>()
        .try_into()
        .expect("u32 data incorrectly initialised")
}});
static U64_DATA: LazyLock<[u64; {LEN_U64} * {NUM_VECS}]> = LazyLock::new(|| {{
    (0..{LEN_U64} * {NUM_VECS})
        .map(|i| i as u64)
        .collect::<Vec<_>>()
        .try_into()
        .expect("u64 data incorrectly initialised")
}});

#[target_feature(enable = "sve")]
fn assert_vector_matches_f32(vector: svfloat32_t, expected: svfloat32_t) {{
    let defined = svrdffr();
    assert!(svptest_first(svptrue_b32(), defined));
    let cmp = svcmpne_f32(defined, vector, expected);
    assert!(!svptest_any(defined, cmp))
}}

#[target_feature(enable = "sve")]
fn assert_vector_matches_f64(vector: svfloat64_t, expected: svfloat64_t) {{
    let defined = svrdffr();
    assert!(svptest_first(svptrue_b64(), defined));
    let cmp = svcmpne_f64(defined, vector, expected);
    assert!(!svptest_any(defined, cmp))
}}

#[target_feature(enable = "sve")]
fn assert_vector_matches_i8(vector: svint8_t, expected: svint8_t) {{
    let defined = svrdffr();
    assert!(svptest_first(svptrue_b8(), defined));
    let cmp = svcmpne_s8(defined, vector, expected);
    assert!(!svptest_any(defined, cmp))
}}

#[target_feature(enable = "sve")]
fn assert_vector_matches_i16(vector: svint16_t, expected: svint16_t) {{
    let defined = svrdffr();
    assert!(svptest_first(svptrue_b16(), defined));
    let cmp = svcmpne_s16(defined, vector, expected);
    assert!(!svptest_any(defined, cmp))
}}

#[target_feature(enable = "sve")]
fn assert_vector_matches_i32(vector: svint32_t, expected: svint32_t) {{
    let defined = svrdffr();
    assert!(svptest_first(svptrue_b32(), defined));
    let cmp = svcmpne_s32(defined, vector, expected);
    assert!(!svptest_any(defined, cmp))
}}

#[target_feature(enable = "sve")]
fn assert_vector_matches_i64(vector: svint64_t, expected: svint64_t) {{
    let defined = svrdffr();
    assert!(svptest_first(svptrue_b64(), defined));
    let cmp = svcmpne_s64(defined, vector, expected);
    assert!(!svptest_any(defined, cmp))
}}

#[target_feature(enable = "sve")]
fn assert_vector_matches_u8(vector: svuint8_t, expected: svuint8_t) {{
    let defined = svrdffr();
    assert!(svptest_first(svptrue_b8(), defined));
    let cmp = svcmpne_u8(defined, vector, expected);
    assert!(!svptest_any(defined, cmp))
}}

#[target_feature(enable = "sve")]
fn assert_vector_matches_u16(vector: svuint16_t, expected: svuint16_t) {{
    let defined = svrdffr();
    assert!(svptest_first(svptrue_b16(), defined));
    let cmp = svcmpne_u16(defined, vector, expected);
    assert!(!svptest_any(defined, cmp))
}}

#[target_feature(enable = "sve")]
fn assert_vector_matches_u32(vector: svuint32_t, expected: svuint32_t) {{
    let defined = svrdffr();
    assert!(svptest_first(svptrue_b32(), defined));
    let cmp = svcmpne_u32(defined, vector, expected);
    assert!(!svptest_any(defined, cmp))
}}

#[target_feature(enable = "sve")]
fn assert_vector_matches_u64(vector: svuint64_t, expected: svuint64_t) {{
    let defined = svrdffr();
    assert!(svptest_first(svptrue_b64(), defined));
    let cmp = svcmpne_u64(defined, vector, expected);
    assert!(!svptest_any(defined, cmp))
}}
"#
    );
}
// Hand-written tests appended verbatim to the generated load/store test
// file. Exercises the FFR (first-fault register) intrinsics directly:
// set/write the FFR, read it back, and compare against expected masks.
lazy_static! {
    static ref MANUAL_TESTS: String = format!(
        "#[simd_test(enable = \"sve\")]
unsafe fn test_ffr() {{
svsetffr();
let ffr = svrdffr();
assert_vector_matches_u8(svdup_n_u8_z(ffr, 1), svindex_u8(1, 0));
let pred = svdupq_n_b8(true, false, true, false, true, false, true, false,
true, false, true, false, true, false, true, false);
svwrffr(pred);
let ffr = svrdffr_z(svptrue_b8());
assert_vector_matches_u8(svdup_n_u8_z(ffr, 1), svdup_n_u8_z(pred, 1));
}}
"
    );
}

View file

@ -0,0 +1,273 @@
#![feature(pattern)]
mod assert_instr;
mod context;
mod expression;
mod input;
mod intrinsic;
mod load_store_tests;
mod matching;
mod predicate_forms;
mod typekinds;
mod wildcards;
mod wildstring;
use intrinsic::Test;
use itertools::Itertools;
use quote::quote;
use std::fs::File;
use std::io::Write;
use std::path::{Path, PathBuf};
use std::process::{Command, Stdio};
use walkdir::WalkDir;
fn main() -> Result<(), String> {
    // Pipeline: open every spec file, parse it as YAML, expand each
    // intrinsic into its concrete variants, emit load/store tests, then
    // write the generated Rust source (to a file, or stdout if no output
    // directory was resolved).
    parse_args()
        .into_iter()
        .map(|(filepath, out)| {
            File::open(&filepath)
                .map(|f| (f, filepath, out))
                .map_err(|e| format!("could not read input file: {e}"))
        })
        .map(|res| {
            let (file, filepath, out) = res?;
            serde_yaml::from_reader(file)
                .map(|input: input::GeneratorInput| (input, filepath, out))
                .map_err(|e| format!("could not parse input file: {e}"))
        })
        // Fail fast: abort on the first file that cannot be read or parsed.
        .collect::<Result<Vec<_>, _>>()?
        .into_iter()
        .map(|(input, filepath, out)| {
            let intrinsics = input.intrinsics.into_iter()
                .map(|intrinsic| intrinsic.generate_variants(&input.ctx))
                .try_collect()
                .map(|mut vv: Vec<_>| {
                    // Sort each intrinsic's variant group by the name of its
                    // first variant so output is deterministic across runs.
                    vv.sort_by_cached_key(|variants| {
                        variants.first().map_or_else(String::default, |variant| {
                            variant.signature.fn_name().to_string()
                        })
                    });
                    vv.into_iter().flatten().collect_vec()
                })?;
            // Split out the intrinsics that need generated load/store tests.
            let loads = intrinsics.iter()
                .filter_map(|i| {
                    if matches!(i.test, Test::Load(..)) {
                        Some(i.clone())
                    } else {
                        None
                    }
                }).collect();
            let stores = intrinsics.iter()
                .filter_map(|i| {
                    if matches!(i.test, Test::Store(..)) {
                        Some(i.clone())
                    } else {
                        None
                    }
                }).collect();
            load_store_tests::generate_load_store_tests(loads, stores, out.as_ref().map(|o| make_tests_filepath(&filepath, o)).as_ref())?;
            Ok((
                input::GeneratorInput {
                    intrinsics,
                    ctx: input.ctx,
                },
                filepath,
                out,
            ))
        })
        .try_for_each(
            |result: context::Result<(input::GeneratorInput, PathBuf, Option<PathBuf>)>| -> context::Result {
                let (generated, filepath, out) = result?;
                // No output directory means "print the generated code to stdout".
                let w = match out {
                    Some(out) => Box::new(
                        File::create(make_output_filepath(&filepath, &out))
                            .map_err(|e| format!("could not create output file: {e}"))?,
                    ) as Box<dyn Write>,
                    None => Box::new(std::io::stdout()) as Box<dyn Write>,
                };
                generate_file(generated, w)
                    .map_err(|e| format!("could not generate output file: {e}"))
            },
        )
}
/// Collects `(input file, optional output directory)` pairs from the CLI.
///
/// Walks `INPUT_DIR` recursively and pairs every regular file found with
/// the output directory. When `OUTPUT_DIR` is omitted, tries to locate
/// `crates/core_arch/src/aarch64/` relative to the running executable; if
/// that directory does not exist the output directory is `None` (stdout).
///
/// Panics (via `assert!`) on a wrong argument count or non-existent paths.
fn parse_args() -> Vec<(PathBuf, Option<PathBuf>)> {
    let mut args_it = std::env::args().skip(1);
    assert!(
        1 <= args_it.len() && args_it.len() <= 2,
        "Usage: cargo run -p stdarch-gen2 -- INPUT_DIR [OUTPUT_DIR]"
    );
    let in_path = Path::new(args_it.next().unwrap().as_str()).to_path_buf();
    assert!(
        in_path.exists() && in_path.is_dir(),
        "invalid path {in_path:#?} given"
    );
    let out_dir = if let Some(dir) = args_it.next() {
        let out_path = Path::new(dir.as_str()).to_path_buf();
        assert!(
            out_path.exists() && out_path.is_dir(),
            "invalid path {out_path:#?} given"
        );
        Some(out_path)
    } else {
        // Default output: the core_arch aarch64 sources, located relative
        // to the binary (works for `cargo run` from the workspace root).
        std::env::current_exe()
            .map(|mut f| {
                f.pop();
                f.push("../../crates/core_arch/src/aarch64/");
                f.exists().then_some(f)
            })
            .ok()
            .flatten()
    };
    // Non-file entries (directories, symlinks to dirs) are skipped.
    WalkDir::new(in_path)
        .into_iter()
        .filter_map(Result::ok)
        .filter(|f| f.file_type().is_file())
        .map(|f| (f.into_path(), out_dir.clone()))
        .collect()
}
/// Writes the "auto-generated, do not modify" file header followed by the
/// rustfmt-formatted code of all intrinsics in `generated_input`.
///
/// The NEON `use` line is appended only when the generator context says
/// NEON types are referenced.
fn generate_file(
    generated_input: input::GeneratorInput,
    mut out: Box<dyn Write>,
) -> std::io::Result<()> {
    write!(
        out,
        r#"// This code is automatically generated. DO NOT MODIFY.
//
// Instead, modify `crates/stdarch-gen2/spec/` and run the following command to re-generate this file:
//
// ```
// cargo run --bin=stdarch-gen2 -- crates/stdarch-gen2/spec
// ```
#![allow(improper_ctypes)]
#[cfg(test)]
use stdarch_test::assert_instr;
use super::*;{uses_neon}
"#,
        uses_neon = generated_input
            .ctx
            .uses_neon_types
            .then_some("\nuse crate::core_arch::arch::aarch64::*;")
            .unwrap_or_default(),
    )?;
    let intrinsics = generated_input.intrinsics;
    // Pipe the token stream through rustfmt before writing it out.
    format_code(out, quote! { #(#intrinsics)* })?;
    Ok(())
}
/// Pipes `input` through `rustfmt` and writes the formatted result to
/// `output`.
///
/// # Errors
///
/// Returns any I/O error from spawning `rustfmt` (it must be on `PATH`),
/// writing to its stdin, or writing the formatted code to `output`.
pub fn format_code(
    mut output: impl std::io::Write,
    input: impl std::fmt::Display,
) -> std::io::Result<()> {
    let mut proc = Command::new("rustfmt")
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .spawn()?;
    {
        // Scope the handle so stdin is dropped (sending EOF) before we
        // wait; otherwise rustfmt could block forever waiting for input.
        let mut stdin = proc.stdin.take().expect("stdin was piped");
        write!(stdin, "{input}")?;
    }
    output.write_all(proc.wait_with_output()?.stdout.as_slice())
}
/// Derive an output file name from an input file and an output directory.
///
/// The name is formed by:
///
/// - ... taking in_filepath.file_name() (dropping all directory components),
/// - ... dropping a .yml or .yaml extension (if present),
/// - ... then dropping a .spec extension (if present).
///
/// Panics if the resulting name is empty, or if file_name() is not UTF-8.
fn make_output_filepath(in_filepath: &Path, out_dirpath: &Path) -> PathBuf {
    // Generated Rust sources simply get a `.rs` extension on the stem.
    let append_rs = |stem: &str| {
        let mut name = stem.to_string();
        name.push_str(".rs");
        name
    };
    make_filepath(in_filepath, out_dirpath, append_rs)
}
/// Like `make_output_filepath`, but prefixes the stem with `ld_st_tests_`
/// for the generated load/store test files.
fn make_tests_filepath(in_filepath: &Path, out_dirpath: &Path) -> PathBuf {
    make_filepath(in_filepath, out_dirpath, |stem: &str| {
        let mut name = String::from("ld_st_tests_");
        name.push_str(stem);
        name.push_str(".rs");
        name
    })
}
/// Builds `out_dirpath / <parent dir of input> / name_formatter(stem)`,
/// where `stem` is the input's file name with any `.yml`/`.yaml` and then
/// `.spec` suffixes stripped (in that order).
///
/// # Panics
///
/// Panics if the input path has no parent directory component, if its
/// file name is not valid UTF-8, or if stripping the suffixes leaves an
/// empty stem.
fn make_filepath<F: FnOnce(&str) -> String>(
    in_filepath: &Path,
    out_dirpath: &Path,
    name_formatter: F,
) -> PathBuf {
    let mut parts = in_filepath.iter();
    let name = parts
        .next_back()
        .and_then(|f| f.to_str())
        .expect("Inputs must have valid, UTF-8 file_name()");
    // Keep only the immediate parent directory of the input file; any
    // higher-level directories are dropped from the output path.
    let dir = parts
        .next_back()
        .expect("Input paths must contain a parent directory");
    // NOTE: trim_end_matches strips *repeated* suffixes, so e.g.
    // `NAME.yml.yml` also reduces to `NAME`.
    let name = name
        .trim_end_matches(".yml")
        .trim_end_matches(".yaml")
        .trim_end_matches(".spec");
    assert!(
        !name.is_empty(),
        "cannot derive an output file name from {in_filepath:?}"
    );
    let mut output = out_dirpath.to_path_buf();
    output.push(dir);
    output.push(name_formatter(name));
    output
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn infer_output_file() {
        // Builds all three paths from components so the assertions are
        // separator-agnostic across platforms.
        macro_rules! t {
            ($src:expr, $outdir:expr, $dst:expr) => {
                let src: PathBuf = $src.iter().collect();
                let outdir: PathBuf = $outdir.iter().collect();
                let dst: PathBuf = $dst.iter().collect();
                assert_eq!(make_output_filepath(&src, &outdir), dst);
            };
        }
        // Documented usage.
        t!(["x", "NAME.spec.yml"], [""], ["x", "NAME.rs"]);
        t!(
            ["x", "NAME.spec.yml"],
            ["a", "b"],
            ["a", "b", "x", "NAME.rs"]
        );
        t!(
            ["x", "y", "NAME.spec.yml"],
            ["out"],
            ["out", "y", "NAME.rs"]
        );
        t!(["x", "NAME.spec.yaml"], ["out"], ["out", "x", "NAME.rs"]);
        t!(["x", "NAME.spec"], ["out"], ["out", "x", "NAME.rs"]);
        t!(["x", "NAME.yml"], ["out"], ["out", "x", "NAME.rs"]);
        t!(["x", "NAME.yaml"], ["out"], ["out", "x", "NAME.rs"]);
        // Unrecognised extensions get treated as part of the stem.
        t!(
            ["x", "NAME.spac.yml"],
            ["out"],
            ["out", "x", "NAME.spac.rs"]
        );
        t!(["x", "NAME.txt"], ["out"], ["out", "x", "NAME.txt.rs"]);
        // Always take the top-level directory from the input path
        t!(
            ["x", "y", "z", "NAME.spec.yml"],
            ["out"],
            ["out", "z", "NAME.rs"]
        );
    }

    // An input whose name is nothing but strippable extensions must panic.
    #[test]
    #[should_panic]
    fn infer_output_file_no_stem() {
        make_output_filepath(Path::new(".spec.yml"), Path::new(""));
    }
}

View file

@ -0,0 +1,170 @@
use proc_macro2::TokenStream;
use quote::ToTokens;
use serde::{Deserialize, Serialize};
use std::fmt;
use crate::context::{self, LocalContext};
use crate::typekinds::{BaseType, BaseTypeKind, TypeKind};
/// Per-bitsize value table: one mandatory default plus optional overrides
/// for byte (8-bit), halfword (16-bit) and doubleword (64-bit) types.
/// 32-bit (word) types always resolve to `default`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct MatchSizeValues<T> {
    // Used when no size-specific override applies.
    pub default: T,
    // 8-bit override.
    pub byte: Option<T>,
    // 16-bit override.
    pub halfword: Option<T>,
    // 64-bit override.
    pub doubleword: Option<T>,
}
impl<T> MatchSizeValues<T> {
    /// Returns the value for `ty`'s bitsize, resolving a type wildcard
    /// through `ctx` first. Bitsizes without an explicit override
    /// (notably 32-bit) fall back to `default`.
    ///
    /// Takes `&self` (the original `&mut self` never mutated; `&self` is
    /// strictly more permissive for callers).
    ///
    /// # Errors
    ///
    /// Errors if the (resolved) type has no fixed bitsize.
    pub fn get(&self, ty: &TypeKind, ctx: &LocalContext) -> context::Result<&T> {
        let base_ty = if let Some(w) = ty.wildcard() {
            ctx.provide_type_wildcard(w)?
        } else {
            ty.clone()
        };

        if let BaseType::Sized(_, bitsize) = base_ty.base_type().unwrap() {
            match (bitsize, &self.byte, &self.halfword, &self.doubleword) {
                (64, _, _, Some(v)) | (16, _, Some(v), _) | (8, Some(v), _, _) => Ok(v),
                _ => Ok(&self.default),
            }
        } else {
            Err(format!("cannot match bitsize to unsized type {ty:?}!"))
        }
    }
}
/// Per-type-kind value table: a default plus optional overrides for
/// float and unsigned types. Signed integers (and any other kind) use
/// `default`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct MatchKindValues<T> {
    // Used when no kind-specific override applies.
    pub default: T,
    // Floating-point override.
    pub float: Option<T>,
    // Unsigned-integer override.
    pub unsigned: Option<T>,
}
impl<T> MatchKindValues<T> {
    /// Returns the value for `ty`'s kind, resolving a type wildcard
    /// through `ctx` first. Kinds without an explicit override fall back
    /// to `default`.
    ///
    /// Takes `&self` (the original `&mut self` never mutated; `&self` is
    /// strictly more permissive for callers).
    pub fn get(&self, ty: &TypeKind, ctx: &LocalContext) -> context::Result<&T> {
        let base_ty = if let Some(w) = ty.wildcard() {
            ctx.provide_type_wildcard(w)?
        } else {
            ty.clone()
        };

        match (
            base_ty.base_type().unwrap().kind(),
            &self.float,
            &self.unsigned,
        ) {
            (BaseTypeKind::Float, Some(v), _) | (BaseTypeKind::UInt, _, Some(v)) => Ok(v),
            _ => Ok(&self.default),
        }
    }
}
/// A value given either directly (`Matched`) or as a table of size-keyed
/// alternatives to be resolved later with `perform_match`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged, deny_unknown_fields)]
pub enum SizeMatchable<T> {
    /// Already resolved to a single value.
    Matched(T),
    /// Still to be resolved against `match_size`'s bitsize.
    Unmatched {
        // Type whose bitsize selects the value; `None` always selects
        // the table's default.
        match_size: Option<TypeKind>,
        #[serde(flatten)]
        values: MatchSizeValues<Box<T>>,
    },
}
impl<T: Clone> SizeMatchable<T> {
    /// Resolves an `Unmatched` entry in place, replacing `self` with
    /// `Matched` holding the value selected by `match_size`'s bitsize
    /// (or the table default when `match_size` is `None`). An entry that
    /// is already `Matched` is left untouched.
    pub fn perform_match(&mut self, ctx: &LocalContext) -> context::Result {
        match self {
            Self::Unmatched {
                match_size: None,
                values: MatchSizeValues { default, .. },
            } => *self = Self::Matched(*default.to_owned()),
            Self::Unmatched {
                match_size: Some(ty),
                values,
            } => *self = Self::Matched(*values.get(ty, ctx)?.to_owned()),
            _ => {}
        }
        Ok(())
    }
}
/// Borrow the resolved value; panics if `perform_match` was never run.
impl<T: fmt::Debug> AsRef<T> for SizeMatchable<T> {
    fn as_ref(&self) -> &T {
        match self {
            SizeMatchable::Matched(v) => v,
            _ => panic!("no match for {self:?} was performed"),
        }
    }
}

/// Mutably borrow the resolved value; panics if unresolved.
impl<T: fmt::Debug> AsMut<T> for SizeMatchable<T> {
    fn as_mut(&mut self) -> &mut T {
        match self {
            SizeMatchable::Matched(v) => v,
            _ => panic!("no match for {self:?} was performed"),
        }
    }
}

/// Tokenizes the resolved value; panics (via `as_ref`) if unresolved.
impl<T: fmt::Debug + ToTokens> ToTokens for SizeMatchable<T> {
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let matched: &T = self.as_ref();
        matched.to_tokens(tokens)
    }
}
/// A value given either directly (`Matched`) or as a table of kind-keyed
/// alternatives to be resolved later with `perform_match`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged, deny_unknown_fields)]
pub enum KindMatchable<T> {
    /// Already resolved to a single value.
    Matched(T),
    /// Still to be resolved against `match_kind`'s type kind.
    Unmatched {
        // Type whose kind selects the value; `None` always selects the
        // table's default.
        match_kind: Option<TypeKind>,
        #[serde(flatten)]
        values: MatchKindValues<Box<T>>,
    },
}
impl<T: Clone> KindMatchable<T> {
    /// Resolves an `Unmatched` entry in place, replacing `self` with
    /// `Matched` holding the value selected by `match_kind`'s type kind
    /// (or the table default when `match_kind` is `None`). An entry that
    /// is already `Matched` is left untouched.
    pub fn perform_match(&mut self, ctx: &LocalContext) -> context::Result {
        match self {
            Self::Unmatched {
                match_kind: None,
                values: MatchKindValues { default, .. },
            } => *self = Self::Matched(*default.to_owned()),
            Self::Unmatched {
                match_kind: Some(ty),
                values,
            } => *self = Self::Matched(*values.get(ty, ctx)?.to_owned()),
            _ => {}
        }
        Ok(())
    }
}
/// Borrow the resolved value; panics if `perform_match` was never run.
impl<T: fmt::Debug> AsRef<T> for KindMatchable<T> {
    fn as_ref(&self) -> &T {
        match self {
            KindMatchable::Matched(v) => v,
            _ => panic!("no match for {self:?} was performed"),
        }
    }
}

/// Mutably borrow the resolved value; panics if unresolved.
impl<T: fmt::Debug> AsMut<T> for KindMatchable<T> {
    fn as_mut(&mut self) -> &mut T {
        match self {
            KindMatchable::Matched(v) => v,
            _ => panic!("no match for {self:?} was performed"),
        }
    }
}

/// Tokenizes the resolved value; panics (via `as_ref`) if unresolved.
impl<T: fmt::Debug + ToTokens> ToTokens for KindMatchable<T> {
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let matched: &T = self.as_ref();
        matched.to_tokens(tokens)
    }
}

View file

@ -0,0 +1,249 @@
use serde::{Deserialize, Serialize};
use serde_with::{DeserializeFromStr, SerializeDisplay};
use std::fmt;
use std::str::FromStr;
use crate::context;
use crate::expression::{Expression, FnCall, IdentifierType};
use crate::intrinsic::Intrinsic;
use crate::typekinds::{ToRepr, TypeKind};
use crate::wildcards::Wildcard;
use crate::wildstring::WildString;
// Function-name suffix for the zeroing predicate form.
const ZEROING_SUFFIX: &str = "_z";
// Function-name suffix for the merging predicate form.
const MERGING_SUFFIX: &str = "_m";
// Function-name suffix for the "don't care" predicate form.
const DONT_CARE_SUFFIX: &str = "_x";
/// How an intrinsic's zeroing predicate form is lowered.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ZeroingMethod {
    /// Drop the specified argument and replace it with a zeroinitializer
    Drop { drop: WildString },
    /// Apply zero selection to the specified variable when zeroing
    Select { select: WildString },
}
impl PartialOrd for ZeroingMethod {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

// All zeroing methods compare as equal: this makes the derived ordering
// of `PredicateForm` (which embeds a `ZeroingMethod`) independent of
// which zeroing strategy is used.
impl Ord for ZeroingMethod {
    fn cmp(&self, _: &Self) -> std::cmp::Ordering {
        std::cmp::Ordering::Equal
    }
}
/// How the "don't care" (`_x`) predicate form is lowered.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum DontCareMethod {
    /// Pick `AsZeroing` or `AsMerging` automatically — see
    /// `PredicateForm::infer_dont_care`.
    #[default]
    Inferred,
    /// Reuse the zeroing lowering (drops the predicated argument when a
    /// `Drop` zeroing method is configured).
    AsZeroing,
    /// Reuse the merging lowering.
    AsMerging,
}
/// Per-intrinsic configuration for how predicate forms are lowered.
#[derive(Debug, Clone, Default, PartialEq, Eq, Deserialize, Serialize)]
pub struct PredicationMethods {
    /// Zeroing method, if the zeroing predicate form is used
    #[serde(default)]
    pub zeroing_method: Option<ZeroingMethod>,
    /// Don't care method, if the don't care predicate form is used
    #[serde(default)]
    pub dont_care_method: DontCareMethod,
}
/// A concrete predicate form (`_m`/`_x`/`_z`) an intrinsic variant is
/// generated for.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum PredicateForm {
    /// Enables merging predicate form
    Merging,
    /// Enables "don't care" predicate form.
    DontCare(DontCareMethod),
    /// Enables zeroing predicate form. If LLVM zeroselection is performed, then
    /// set the `select` field to the variable that gets set. Otherwise set the
    /// `drop` field if the zeroinitializer replaces a predicate when merging.
    Zeroing(ZeroingMethod),
}
impl PredicateForm {
    /// Returns the function-name suffix (`_m`, `_x` or `_z`) for this form.
    pub fn get_suffix(&self) -> &'static str {
        match self {
            PredicateForm::Zeroing { .. } => ZEROING_SUFFIX,
            PredicateForm::Merging => MERGING_SUFFIX,
            PredicateForm::DontCare { .. } => DONT_CARE_SUFFIX,
        }
    }

    /// Builds an `svdup_n_*` call producing an all-zeros vector of `ty`
    /// (float or integer zero depending on the base type).
    pub fn make_zeroinitializer(ty: &TypeKind) -> Expression {
        FnCall::new_expression(
            format!("svdup_n_{}", ty.acle_notation_repr())
                .parse()
                .unwrap(),
            vec![if ty.base_type().unwrap().is_float() {
                Expression::FloatConstant(0.0)
            } else {
                Expression::IntConstant(0)
            }],
        )
    }

    /// Builds an `svsel_*` call: `op_var` where `pg_var` is set, zero
    /// elsewhere (the LLVM zero-selection lowering).
    pub fn make_zeroselector(pg_var: WildString, op_var: WildString, ty: &TypeKind) -> Expression {
        FnCall::new_expression(
            format!("svsel_{}", ty.acle_notation_repr())
                .parse()
                .unwrap(),
            vec![
                Expression::Identifier(pg_var, IdentifierType::Variable),
                Expression::Identifier(op_var, IdentifierType::Variable),
                Self::make_zeroinitializer(ty),
            ],
        )
    }

    /// Post-processes a built intrinsic for this form: drops the
    /// zero-replaced argument for `Drop`-style zeroing, and likewise for
    /// don't-care forms lowered as zeroing. Other forms need no work.
    pub fn post_build(&self, intrinsic: &mut Intrinsic) -> context::Result {
        // Drop the argument
        match self {
            PredicateForm::Zeroing(ZeroingMethod::Drop { drop: drop_var }) => {
                intrinsic.signature.drop_argument(drop_var)?
            }
            PredicateForm::DontCare(DontCareMethod::AsZeroing) => {
                if let ZeroingMethod::Drop { drop } = intrinsic
                    .input
                    .predication_methods
                    .zeroing_method
                    .to_owned()
                    .ok_or_else(|| {
                        "DontCareMethod::AsZeroing without zeroing method.".to_string()
                    })?
                {
                    intrinsic.signature.drop_argument(&drop)?
                }
            }
            _ => {}
        }

        Ok(())
    }

    /// Resolves `DontCareMethod::Inferred`: lower as zeroing when a
    /// zeroing form with a `Drop` method is also being generated,
    /// otherwise lower as merging. Explicit methods pass through.
    fn infer_dont_care(mask: &PredicationMask, methods: &PredicationMethods) -> PredicateForm {
        let method = if methods.dont_care_method == DontCareMethod::Inferred {
            if mask.has_zeroing()
                && matches!(methods.zeroing_method, Some(ZeroingMethod::Drop { .. }))
            {
                DontCareMethod::AsZeroing
            } else {
                DontCareMethod::AsMerging
            }
        } else {
            methods.dont_care_method
        };

        PredicateForm::DontCare(method)
    }

    /// Expands a predication mask into the ordered list of concrete forms
    /// to generate (merging, then don't-care, then zeroing).
    ///
    /// # Errors
    ///
    /// Errors if zeroing is requested but no zeroing method is configured.
    pub fn compile_list(
        mask: &PredicationMask,
        methods: &PredicationMethods,
    ) -> context::Result<Vec<PredicateForm>> {
        let mut forms = Vec::new();

        if mask.has_merging() {
            forms.push(PredicateForm::Merging)
        }

        if mask.has_dont_care() {
            forms.push(Self::infer_dont_care(mask, methods))
        }

        if mask.has_zeroing() {
            if let Some(method) = methods.zeroing_method.to_owned() {
                forms.push(PredicateForm::Zeroing(method))
            } else {
                return Err(
                    "cannot create a zeroing variant without a zeroing method specified!"
                        .to_string(),
                );
            }
        }

        Ok(forms)
    }
}
/// Set of predicate forms requested for an intrinsic, parsed from a
/// combination of `m`/`x`/`z` modifiers in its name wildcard.
#[derive(
    Debug, Clone, Copy, Default, PartialEq, Eq, Hash, DeserializeFromStr, SerializeDisplay,
)]
pub struct PredicationMask {
    /// Merging
    m: bool,
    /// Don't care
    x: bool,
    /// Zeroing
    z: bool,
}
impl PredicationMask {
    /// True if the merging (`_m`) form was requested.
    pub fn has_merging(&self) -> bool {
        self.m
    }

    /// True if the don't-care (`_x`) form was requested.
    pub fn has_dont_care(&self) -> bool {
        self.x
    }

    /// True if the zeroing (`_z`) form was requested.
    pub fn has_zeroing(&self) -> bool {
        self.z
    }
}
impl FromStr for PredicationMask {
    type Err = String;

    /// Parses any combination of the modifiers `m`, `x` and `z`; at least
    /// one must be present and anything else is rejected.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let mut mask = Self::default();
        for byte in s.bytes() {
            match byte {
                b'm' => mask.m = true,
                b'x' => mask.x = true,
                b'z' => mask.z = true,
                other => {
                    return Err(format!(
                        "unknown predicate form modifier: {}",
                        char::from(other)
                    ));
                }
            }
        }
        (mask.m || mask.x || mask.z)
            .then_some(mask)
            .ok_or_else(|| "invalid predication mask".to_string())
    }
}
impl fmt::Display for PredicationMask {
    /// Writes the active modifiers in canonical `m`, `x`, `z` order.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        for (enabled, modifier) in [(self.m, 'm'), (self.x, 'x'), (self.z, 'z')] {
            if enabled {
                write!(f, "{modifier}")?;
            }
        }
        Ok(())
    }
}
impl TryFrom<&WildString> for PredicationMask {
    type Error = String;

    /// Extracts the mask from the first `PredicateForms` wildcard in a
    /// name template.
    ///
    /// # Errors
    ///
    /// Errors if the name contains no predicate-forms wildcard.
    fn try_from(value: &WildString) -> Result<Self, Self::Error> {
        value
            .wildcards()
            .find_map(|w| {
                if let Wildcard::PredicateForms(mask) = w {
                    Some(*mask)
                } else {
                    None
                }
            })
            .ok_or_else(|| "no predicate forms were specified in the name".to_string())
    }
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,179 @@
use lazy_static::lazy_static;
use regex::Regex;
use serde_with::{DeserializeFromStr, SerializeDisplay};
use std::fmt;
use std::str::FromStr;
use crate::{
predicate_forms::PredicationMask,
typekinds::{ToRepr, TypeKind, TypeKindOptions, VectorTupleSize},
};
/// A placeholder appearing in name/type templates, expanded during
/// generation. Parsed from and rendered back to the `{...}` syntax by the
/// `FromStr`/`Display` impls below.
#[derive(Debug, Clone, PartialEq, Eq, Hash, SerializeDisplay, DeserializeFromStr)]
pub enum Wildcard {
    /// Type taken from the typeset, optionally selected by index
    /// (an unspecified index defaults to 0 — see `get_typeset_index`).
    Type(Option<usize>),
    /// NEON type derivated by a base type
    NEONType(Option<usize>, Option<VectorTupleSize>),
    /// SVE type derivated by a base type
    SVEType(Option<usize>, Option<VectorTupleSize>),
    /// Integer representation of bitsize
    Size(Option<usize>),
    /// Integer representation of bitsize minus one
    SizeMinusOne(Option<usize>),
    /// Literal representation of the bitsize: b(yte), h(half), w(ord) or d(ouble)
    SizeLiteral(Option<usize>),
    /// Literal representation of the type kind: f(loat), s(igned), u(nsigned)
    TypeKind(Option<usize>, Option<TypeKindOptions>),
    /// Log2 of the size in bytes
    SizeInBytesLog2(Option<usize>),
    /// Predicate to be inferred from the specified type
    Predicate(Option<usize>),
    /// Predicate to be inferred from the greatest type
    MaxPredicate,

    /// Reinterpret the inner wildcard's type as another type; written
    /// `{<wildcard> as <type>}`.
    Scale(Box<Wildcard>, Box<TypeKind>),

    // Other wildcards
    /// Placeholder for the LLVM intrinsic link name (`{llvm_link}`).
    LLVMLink,
    /// Marker for the `_n` naming variant (`{_n}`).
    NVariant,
    /// Predicate forms to use and placeholder for a predicate form function name modifier
    PredicateForms(PredicationMask),

    /// User-set wildcard through `substitutions`
    Custom(String),
}
impl Wildcard {
pub fn is_nonpredicate_type(&self) -> bool {
matches!(
self,
Wildcard::Type(..) | Wildcard::NEONType(..) | Wildcard::SVEType(..)
)
}
pub fn get_typeset_index(&self) -> Option<usize> {
match self {
Wildcard::Type(idx) | Wildcard::NEONType(idx, ..) | Wildcard::SVEType(idx, ..) => {
Some(idx.unwrap_or(0))
}
_ => None,
}
}
}
impl FromStr for Wildcard {
    type Err = String;

    /// Parses the inside of a `{...}` wildcard.
    ///
    /// Grammar (see `RE`): `name[_xN][ [index] ][.modifiers][ as <type>]`,
    /// where `_xN` is a vector tuple size (2-4), `[index]` selects a
    /// typeset entry, `.modifiers` applies to `type_kind`, and a trailing
    /// `as <type>` wraps the result in `Wildcard::Scale`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        lazy_static! {
            static ref RE: Regex = Regex::new(r"^(?P<wildcard>\w+?)(?:_x(?P<tuple_size>[2-4]))?(?:\[(?P<index>\d+)\])?(?:\.(?P<modifiers>\w+))?(?:\s+as\s+(?P<scale_to>.*?))?$").unwrap();
        }

        if let Some(c) = RE.captures(s) {
            let wildcard_name = &c["wildcard"];
            let inputset_index = c
                .name("index")
                .map(<&str>::from)
                .map(usize::from_str)
                .transpose()
                .map_err(|_| format!("{:#?} is not a valid type index", &c["index"]))?;
            let tuple_size = c
                .name("tuple_size")
                .map(<&str>::from)
                .map(VectorTupleSize::from_str)
                .transpose()
                .map_err(|_| format!("{:#?} is not a valid tuple size", &c["tuple_size"]))?;
            let modifiers = c.name("modifiers").map(<&str>::from);

            // Reject any capture combination the keyword does not accept
            // (e.g. a tuple size on `size`, an index on `llvm_link`).
            let wildcard = match (wildcard_name, inputset_index, tuple_size, modifiers) {
                ("type", index, None, None) => Ok(Wildcard::Type(index)),
                ("neon_type", index, tuple, None) => Ok(Wildcard::NEONType(index, tuple)),
                ("sve_type", index, tuple, None) => Ok(Wildcard::SVEType(index, tuple)),
                ("size", index, None, None) => Ok(Wildcard::Size(index)),
                ("size_minus_one", index, None, None) => Ok(Wildcard::SizeMinusOne(index)),
                ("size_literal", index, None, None) => Ok(Wildcard::SizeLiteral(index)),
                ("type_kind", index, None, modifiers) => Ok(Wildcard::TypeKind(
                    index,
                    modifiers.map(|modifiers| modifiers.parse()).transpose()?,
                )),
                ("size_in_bytes_log2", index, None, None) => Ok(Wildcard::SizeInBytesLog2(index)),
                ("predicate", index, None, None) => Ok(Wildcard::Predicate(index)),
                ("max_predicate", None, None, None) => Ok(Wildcard::MaxPredicate),
                ("llvm_link", None, None, None) => Ok(Wildcard::LLVMLink),
                ("_n", None, None, None) => Ok(Wildcard::NVariant),
                (w, None, None, None) if w.starts_with('_') => {
                    // test for predicate forms
                    let pf_mask = PredicationMask::from_str(&w[1..]);
                    if let Ok(mask) = pf_mask {
                        if mask.has_merging() {
                            Ok(Wildcard::PredicateForms(mask))
                        } else {
                            Err("cannot add predication without a Merging form".to_string())
                        }
                    } else {
                        Err(format!("invalid wildcard `{s:#?}`"))
                    }
                }
                // Any other bare word is a user-defined substitution.
                (cw, None, None, None) => Ok(Wildcard::Custom(cw.to_string())),
                _ => Err(format!("invalid wildcard `{s:#?}`")),
            }?;

            let scale_to = c
                .name("scale_to")
                .map(<&str>::from)
                .map(TypeKind::from_str)
                .transpose()
                .map_err(|_| format!("{:#?} is not a valid type", &c["scale_to"]))?;

            if let Some(scale_to) = scale_to {
                Ok(Wildcard::Scale(Box::new(wildcard), Box::new(scale_to)))
            } else {
                Ok(wildcard)
            }
        } else {
            Err(format!("invalid wildcard `{s:#?}`"))
        }
    }
}
impl fmt::Display for Wildcard {
    /// Renders the wildcard in the same syntax accepted by `FromStr`, so
    /// Display and parsing round-trip.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Type(None) => write!(f, "type"),
            Self::Type(Some(index)) => write!(f, "type[{index}]"),
            Self::NEONType(None, None) => write!(f, "neon_type"),
            Self::NEONType(Some(index), None) => write!(f, "neon_type[{index}]"),
            Self::NEONType(None, Some(tuple_size)) => write!(f, "neon_type_x{tuple_size}"),
            Self::NEONType(Some(index), Some(tuple_size)) => {
                write!(f, "neon_type_x{tuple_size}[{index}]")
            }
            Self::SVEType(None, None) => write!(f, "sve_type"),
            Self::SVEType(Some(index), None) => write!(f, "sve_type[{index}]"),
            Self::SVEType(None, Some(tuple_size)) => write!(f, "sve_type_x{tuple_size}"),
            Self::SVEType(Some(index), Some(tuple_size)) => {
                write!(f, "sve_type_x{tuple_size}[{index}]")
            }
            Self::Size(None) => write!(f, "size"),
            Self::Size(Some(index)) => write!(f, "size[{index}]"),
            Self::SizeMinusOne(None) => write!(f, "size_minus_one"),
            Self::SizeMinusOne(Some(index)) => write!(f, "size_minus_one[{index}]"),
            Self::SizeLiteral(None) => write!(f, "size_literal"),
            Self::SizeLiteral(Some(index)) => write!(f, "size_literal[{index}]"),
            Self::TypeKind(None, None) => write!(f, "type_kind"),
            Self::TypeKind(None, Some(opts)) => write!(f, "type_kind.{opts}"),
            Self::TypeKind(Some(index), None) => write!(f, "type_kind[{index}]"),
            Self::TypeKind(Some(index), Some(opts)) => write!(f, "type_kind[{index}].{opts}"),
            Self::SizeInBytesLog2(None) => write!(f, "size_in_bytes_log2"),
            Self::SizeInBytesLog2(Some(index)) => write!(f, "size_in_bytes_log2[{index}]"),
            Self::Predicate(None) => write!(f, "predicate"),
            Self::Predicate(Some(index)) => write!(f, "predicate[{index}]"),
            Self::MaxPredicate => write!(f, "max_predicate"),
            Self::LLVMLink => write!(f, "llvm_link"),
            Self::NVariant => write!(f, "_n"),
            Self::PredicateForms(mask) => write!(f, "_{mask}"),
            // Scaled wildcards render the target type in Rust notation.
            Self::Scale(wildcard, ty) => write!(f, "{wildcard} as {}", ty.rust_repr()),
            Self::Custom(cw) => write!(f, "{cw}"),
        }
    }
}

View file

@ -0,0 +1,353 @@
use itertools::Itertools;
use proc_macro2::TokenStream;
use quote::{quote, ToTokens, TokenStreamExt};
use serde_with::{DeserializeFromStr, SerializeDisplay};
use std::str::pattern::Pattern;
use std::{fmt, str::FromStr};
use crate::context::LocalContext;
use crate::typekinds::{ToRepr, TypeRepr};
use crate::wildcards::Wildcard;
/// One segment of a `WildString`: either literal text or a wildcard to be
/// expanded later.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WildStringPart {
    String(String),
    Wildcard(Wildcard),
}
/// Wildcard-able string
///
/// Serialized via its `Display` form and deserialized via `FromStr`, so
/// the `{wildcard}` syntax survives a serde round-trip.
#[derive(Debug, Clone, PartialEq, Eq, Default, SerializeDisplay, DeserializeFromStr)]
pub struct WildString(Vec<WildStringPart>);
impl WildString {
    /// Returns `true` if any part is an unexpanded wildcard.
    pub fn has_wildcards(&self) -> bool {
        self.0
            .iter()
            .any(|part| matches!(part, WildStringPart::Wildcard(..)))
    }

    /// Iterates over the wildcard parts only, skipping literal text.
    pub fn wildcards(&self) -> impl Iterator<Item = &Wildcard> + '_ {
        self.0.iter().filter_map(|part| match part {
            WildStringPart::Wildcard(w) => Some(w),
            _ => None,
        })
    }

    pub fn iter(&self) -> impl Iterator<Item = &WildStringPart> + '_ {
        self.0.iter()
    }

    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut WildStringPart> + '_ {
        self.0.iter_mut()
    }

    /// Note: renders the whole string (allocates) before testing the prefix.
    pub fn starts_with(&self, s2: &str) -> bool {
        self.to_string().starts_with(s2)
    }

    pub fn prepend_str(&mut self, s: impl Into<String>) {
        self.0.insert(0, WildStringPart::String(s.into()))
    }

    pub fn push_str(&mut self, s: impl Into<String>) {
        self.0.push(WildStringPart::String(s.into()))
    }

    pub fn push_wildcard(&mut self, w: Wildcard) {
        self.0.push(WildStringPart::Wildcard(w))
    }

    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Returns a copy with `from` replaced by `to` in every literal part;
    /// wildcard parts are left untouched.
    pub fn replace<'a, P>(&'a self, from: P, to: &str) -> WildString
    where
        P: Pattern<'a> + Copy,
    {
        WildString(
            self.0
                .iter()
                .map(|part| match part {
                    WildStringPart::String(s) => WildStringPart::String(s.replace(from, to)),
                    part => part.clone(),
                })
                .collect(),
        )
    }

    /// Expands all wildcards in place using ACLE type notation.
    pub fn build_acle(&mut self, ctx: &LocalContext) -> Result<(), String> {
        self.build(ctx, TypeRepr::ACLENotation)
    }

    /// Expands every wildcard part in place: first through the context's
    /// substitutions, otherwise as a type rendered with `repr`.
    pub fn build(&mut self, ctx: &LocalContext, repr: TypeRepr) -> Result<(), String> {
        self.iter_mut().try_for_each(|wp| -> Result<(), String> {
            if let WildStringPart::Wildcard(w) = wp {
                let value = ctx
                    .provide_substitution_wildcard(w)
                    .or_else(|_| ctx.provide_type_wildcard(w).map(|ty| ty.repr(repr)))?;
                *wp = WildStringPart::String(value);
            }
            Ok(())
        })
    }
}
impl From<String> for WildString {
fn from(s: String) -> Self {
WildString(vec![WildStringPart::String(s)])
}
}
impl FromStr for WildString {
    type Err = String;

    /// Parses a template string into literal and wildcard parts.
    ///
    /// `{name}` introduces a wildcard; `{{` and `}}` escape literal
    /// braces. One level of brace nesting is allowed *inside* a wildcard
    /// to support `{foo as {bar}}` scaling.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Single-pass state machine over char indices. `start` always
        // marks the byte offset where the current span began; `at` records
        // where a brace was seen so errors can point at it.
        enum State {
            Normal { start: usize },
            Wildcard { start: usize, count: usize },
            EscapeTokenOpen { start: usize, at: usize },
            EscapeTokenClose { start: usize, at: usize },
        }

        let mut ws = WildString::default();
        match s
            .char_indices()
            .try_fold(State::Normal { start: 0 }, |state, (idx, ch)| {
                match (state, ch) {
                    // A lone brace might be the start of an escape pair;
                    // defer the decision until the next character.
                    (State::Normal { start }, '{') => Ok(State::EscapeTokenOpen { start, at: idx }),
                    (State::Normal { start }, '}') => {
                        Ok(State::EscapeTokenClose { start, at: idx })
                    }
                    // Doubled brace: flush the pending literal (keeping one
                    // brace) and continue as literal text.
                    (State::EscapeTokenOpen { start, at }, '{')
                    | (State::EscapeTokenClose { start, at }, '}') => {
                        if start < at {
                            ws.push_str(&s[start..at])
                        }

                        Ok(State::Normal { start: idx })
                    }
                    (State::EscapeTokenOpen { at, .. }, '}') => Err(format!(
                        "empty wildcard given in string {s:?} at position {at}"
                    )),
                    // Single `{` followed by content: a wildcard begins.
                    (State::EscapeTokenOpen { start, at }, _) => {
                        if start < at {
                            ws.push_str(&s[start..at])
                        }

                        Ok(State::Wildcard {
                            start: idx,
                            count: 0,
                        })
                    }
                    (State::EscapeTokenClose { at, .. }, _) => Err(format!(
                        "closing a non-wildcard/bad escape in string {s:?} at position {at}"
                    )),
                    // Nesting wildcards is only supported for `{foo as {bar}}`, wildcards cannot be
                    // nested at the start of a WildString.
                    (State::Wildcard { start, count }, '{') => Ok(State::Wildcard {
                        start,
                        count: count + 1,
                    }),
                    // Outermost `}` ends the wildcard; parse its contents.
                    (State::Wildcard { start, count: 0 }, '}') => {
                        ws.push_wildcard(s[start..idx].parse()?);
                        Ok(State::Normal { start: idx + 1 })
                    }
                    (State::Wildcard { start, count }, '}') => Ok(State::Wildcard {
                        start,
                        count: count - 1,
                    }),
                    (state @ State::Normal { .. }, _) | (state @ State::Wildcard { .. }, _) => {
                        Ok(state)
                    }
                }
            })? {
            // End of input: flush any trailing literal; anything else
            // still open is an error.
            State::Normal { start } => {
                if start < s.len() {
                    ws.push_str(&s[start..]);
                }

                Ok(ws)
            }
            State::EscapeTokenOpen { at, .. } | State::Wildcard { start: at, .. } => Err(format!(
                "unclosed wildcard in string {s:?} at position {at}"
            )),
            State::EscapeTokenClose { at, .. } => Err(format!(
                "closing a non-wildcard/bad escape in string {s:?} at position {at}"
            )),
        }
    }
}
impl fmt::Display for WildString {
    /// Renders literal parts verbatim and wildcards re-wrapped in braces,
    /// reproducing the syntax `FromStr` accepts.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        for part in &self.0 {
            match part {
                WildStringPart::String(s) => write!(f, "{s}")?,
                WildStringPart::Wildcard(w) => write!(f, "{{{w}}}")?,
            }
        }
        Ok(())
    }
}
impl ToTokens for WildString {
    /// Emits the rendered string as a literal token. All wildcards must
    /// have been expanded (`build`) first; a leftover wildcard is a
    /// programming error, hence the assert.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        assert!(
            !self.has_wildcards(),
            "cannot convert string with wildcards {self:?} to TokenStream"
        );
        let str = self.to_string();
        tokens.append_all(quote! { #str })
    }
}
// Parser round-trip and error-path tests for WildString.
#[cfg(test)]
mod tests {
    use crate::typekinds::*;
    use crate::wildstring::*;

    #[test]
    fn test_empty_string() {
        let ws: WildString = "".parse().unwrap();
        assert_eq!(ws.0.len(), 0);
    }

    #[test]
    fn test_plain_string() {
        let ws: WildString = "plain string".parse().unwrap();
        assert_eq!(ws.0.len(), 1);
        assert_eq!(
            ws,
            WildString(vec![WildStringPart::String("plain string".to_string())])
        )
    }

    // `{{`/`}}` escape to literal braces and produce no wildcard parts.
    #[test]
    fn test_escaped_curly_brackets() {
        let ws: WildString = "VALUE = {{value}}".parse().unwrap();
        assert_eq!(ws.to_string(), "VALUE = {value}");
        assert!(!ws.has_wildcards());
    }

    // Escapes and a wildcard can sit adjacent: `{{{type}}}` is
    // literal `{`, wildcard `type`, literal `}`.
    #[test]
    fn test_escaped_curly_brackets_wildcard() {
        let ws: WildString = "TYPE = {{{type}}}".parse().unwrap();
        assert_eq!(ws.to_string(), "TYPE = {{type}}");
        assert_eq!(ws.0.len(), 4);
        assert!(ws.has_wildcards());
    }

    #[test]
    fn test_wildcard_right_boundary() {
        let s = "string test {type}";
        let ws: WildString = s.parse().unwrap();
        assert_eq!(&ws.to_string(), s);
        assert!(ws.has_wildcards());
    }

    #[test]
    fn test_wildcard_left_boundary() {
        let s = "{type} string test";
        let ws: WildString = s.parse().unwrap();
        assert_eq!(&ws.to_string(), s);
        assert!(ws.has_wildcards());
    }

    // `{a as {b}}` nests one wildcard inside another via Scale.
    #[test]
    fn test_recursive_wildcard() {
        let s = "string test {type[0] as {type[1]}}";
        let ws: WildString = s.parse().unwrap();
        assert_eq!(ws.0.len(), 2);
        assert_eq!(
            ws,
            WildString(vec![
                WildStringPart::String("string test ".to_string()),
                WildStringPart::Wildcard(Wildcard::Scale(
                    Box::new(Wildcard::Type(Some(0))),
                    Box::new(TypeKind::Wildcard(Wildcard::Type(Some(1)))),
                ))
            ])
        );
    }

    #[test]
    fn test_scale_wildcard() {
        let s = "string {type[0] as i8} test";
        let ws: WildString = s.parse().unwrap();
        assert_eq!(ws.0.len(), 3);
        assert_eq!(
            ws,
            WildString(vec![
                WildStringPart::String("string ".to_string()),
                WildStringPart::Wildcard(Wildcard::Scale(
                    Box::new(Wildcard::Type(Some(0))),
                    Box::new(TypeKind::Base(BaseType::Sized(BaseTypeKind::Int, 8))),
                )),
                WildStringPart::String(" test".to_string())
            ])
        );
    }

    #[test]
    fn test_solitaire_wildcard() {
        let ws: WildString = "{type}".parse().unwrap();
        assert_eq!(ws.0.len(), 1);
        assert_eq!(
            ws,
            WildString(vec![WildStringPart::Wildcard(Wildcard::Type(None))])
        )
    }

    // Error paths: empty wildcard and every unbalanced-brace position.
    #[test]
    fn test_empty_wildcard() {
        "string {}"
            .parse::<WildString>()
            .expect_err("expected parse error");
    }

    #[test]
    fn test_invalid_open_wildcard_right() {
        "string {"
            .parse::<WildString>()
            .expect_err("expected parse error");
    }

    #[test]
    fn test_invalid_close_wildcard_right() {
        "string }"
            .parse::<WildString>()
            .expect_err("expected parse error");
    }

    #[test]
    fn test_invalid_open_wildcard_left() {
        "{string"
            .parse::<WildString>()
            .expect_err("expected parse error");
    }

    #[test]
    fn test_invalid_close_wildcard_left() {
        "}string"
            .parse::<WildString>()
            .expect_err("expected parse error");
    }

    #[test]
    fn test_consecutive_wildcards() {
        let s = "svprf{size_literal[1]}_gather_{type[0]}{index_or_offset}";
        let ws: WildString = s.parse().unwrap();
        assert_eq!(ws.to_string(), s)
    }
}