aarch64: Fix streaming-compatible code with -mtrack-speculation [PR113805]

This patch makes -mtrack-speculation work on streaming-compatible
functions.  There were two related issues.  The first is that the
streaming-compatible code was using TB(N)Z unconditionally, whereas
those instructions are not allowed with speculation tracking.
That part can be fixed in a similar way to the recent eh_return
fix (PR112987).

The second issue was that the speculation-tracking pass runs
before some of the conditional branches are inserted.  It isn't
safe to insert the branches any earlier, so the patch instead adds
a second speculation-tracking pass that runs afterwards.  The new
pass is only used for streaming-compatible functions.

The testcase is adapted from call_sm_switch_1.c.

gcc/
	PR target/113805
	* config/aarch64/aarch64-passes.def (pass_late_track_speculation):
	New pass.
	* config/aarch64/aarch64-protos.h (make_pass_late_track_speculation):
	Declare.
	* config/aarch64/aarch64.md (is_call): New attribute.
	(*and<mode>3nr_compare0): Rename to...
	(@aarch64_and<mode>3nr_compare0): ...this.
	* config/aarch64/aarch64-sme.md (aarch64_get_sme_state)
	(aarch64_tpidr2_save, aarch64_tpidr2_restore): Add is_call attributes.
	* config/aarch64/aarch64-speculation.cc: Update file comment to
	describe the new late pass.
	(aarch64_do_track_speculation): Handle is_call insns like other calls.
	(pass_track_speculation): Add an is_late member variable.
	(pass_track_speculation::gate): Run the late pass for streaming-
	compatible functions and the early pass for other functions.
	(make_pass_track_speculation): Update accordingly.
	(make_pass_late_track_speculation): New function.
	* config/aarch64/aarch64.cc (aarch64_gen_test_and_branch): New
	function.
	(aarch64_guard_switch_pstate_sm): Use it.

gcc/testsuite/
	PR target/113805
	* gcc.target/aarch64/sme/call_sm_switch_11.c: New test.
This commit is contained in:
Richard Sandiford 2024-02-20 11:29:06 +00:00
parent ecfcc362b7
commit 98702303e2
7 changed files with 291 additions and 21 deletions

View file

@@ -22,6 +22,7 @@ INSERT_PASS_BEFORE (pass_sched, 1, pass_aarch64_early_ra);
INSERT_PASS_AFTER (pass_regrename, 1, pass_fma_steering);
INSERT_PASS_BEFORE (pass_reorder_blocks, 1, pass_track_speculation);
INSERT_PASS_BEFORE (pass_late_thread_prologue_and_epilogue, 1, pass_switch_pstate_sm);
INSERT_PASS_BEFORE (pass_late_thread_prologue_and_epilogue, 1, pass_late_track_speculation);
INSERT_PASS_AFTER (pass_machine_reorg, 1, pass_tag_collision_avoidance);
INSERT_PASS_BEFORE (pass_shorten_branches, 1, pass_insert_bti);
INSERT_PASS_AFTER (pass_if_after_combine, 1, pass_cc_fusion);

View file

@@ -1075,6 +1075,7 @@ std::string aarch64_get_extension_string_for_isa_flags (aarch64_feature_flags,
rtl_opt_pass *make_pass_aarch64_early_ra (gcc::context *);
rtl_opt_pass *make_pass_fma_steering (gcc::context *);
rtl_opt_pass *make_pass_track_speculation (gcc::context *);
rtl_opt_pass *make_pass_late_track_speculation (gcc::context *);
rtl_opt_pass *make_pass_tag_collision_avoidance (gcc::context *);
rtl_opt_pass *make_pass_insert_bti (gcc::context *ctxt);
rtl_opt_pass *make_pass_cc_fusion (gcc::context *ctxt);

View file

@@ -105,6 +105,7 @@
(clobber (reg:CC CC_REGNUM))]
""
"bl\t__arm_sme_state"
[(set_attr "is_call" "yes")]
)
(define_insn "aarch64_read_svcr"
@@ -242,6 +243,7 @@
(clobber (reg:CC CC_REGNUM))]
""
"bl\t__arm_tpidr2_save"
[(set_attr "is_call" "yes")]
)
;; Set PSTATE.ZA to 1. If ZA was previously dormant or active,
@@ -358,6 +360,7 @@
(clobber (reg:CC CC_REGNUM))]
""
"bl\t__arm_tpidr2_restore"
[(set_attr "is_call" "yes")]
)
;; Check whether a lazy save set up by aarch64_save_za was committed

View file

@@ -39,10 +39,10 @@
#include "insn-config.h"
#include "recog.h"
/* This pass scans the RTL just before the final branch
re-organisation pass. The aim is to identify all places where
there is conditional control flow and to insert code that tracks
any speculative execution of a conditional branch.
/* This pass scans the RTL insns late in the RTL pipeline. The aim is
to identify all places where there is conditional control flow and
to insert code that tracks any speculative execution of a conditional
branch.
To do this we reserve a call-clobbered register (so that it can be
initialized very early in the function prologue) that can then be
@@ -131,11 +131,22 @@
carry the tracking state in SP for this period of time unless the
tracker value is needed at that point in time.
We run the pass just before the final branch reorganization pass so
that we can handle most of the conditional branch cases using the
standard edge insertion code. The reorg pass will hopefully clean
things up for afterwards so that the results aren't too
horrible. */
We run the pass while the CFG is still present so that we can handle
most of the conditional branch cases using the standard edge insertion
code. Where possible, we prefer to run the pass just before the final
branch reorganization pass. That pass will then hopefully clean things
up afterwards so that the results aren't too horrible.
However, we must run the pass after all conditional branches have
been inserted. switch_pstate_sm inserts conditional branches for
streaming-compatible code, and so for streaming-compatible functions,
this pass must run after that one.
We handle this by having two copies of the pass: the normal one that
runs before branch reorganization, and a "late" one that runs just
before late_thread_prologue_and_epilogue. The two passes have
mutually exclusive gates, with the normal pass being chosen wherever
possible. */
/* Generate a code sequence to clobber SP if speculating incorrectly. */
static rtx_insn *
@@ -315,11 +326,15 @@ aarch64_do_track_speculation ()
needs_tracking = true;
}
if (CALL_P (insn))
if (CALL_P (insn)
|| (NONDEBUG_INSN_P (insn)
&& recog_memoized (insn) >= 0
&& get_attr_is_call (insn) == IS_CALL_YES))
{
bool tailcall
= (SIBLING_CALL_P (insn)
|| find_reg_note (insn, REG_NORETURN, NULL_RTX));
= (CALL_P (insn)
&& (SIBLING_CALL_P (insn)
|| find_reg_note (insn, REG_NORETURN, NULL_RTX)));
/* Tailcalls are like returns, we can eliminate the
transfer between the tracker register and SP if we
@@ -461,21 +476,28 @@ const pass_data pass_data_aarch64_track_speculation =
class pass_track_speculation : public rtl_opt_pass
{
public:
pass_track_speculation(gcc::context *ctxt)
: rtl_opt_pass(pass_data_aarch64_track_speculation, ctxt)
{}
public:
pass_track_speculation(gcc::context *ctxt, bool is_late)
: rtl_opt_pass(pass_data_aarch64_track_speculation, ctxt),
is_late (is_late)
{}
/* opt_pass methods: */
virtual bool gate (function *)
{
return aarch64_track_speculation;
return (aarch64_track_speculation
&& (is_late == bool (TARGET_STREAMING_COMPATIBLE)));
}
virtual unsigned int execute (function *)
{
return aarch64_do_track_speculation ();
}
private:
/* Whether this is the late pass that runs before late prologue/epilogue
insertion, or the normal pass that runs before branch reorganization. */
bool is_late;
}; // class pass_track_speculation.
} // anon namespace.
@@ -483,5 +505,11 @@ class pass_track_speculation : public rtl_opt_pass
rtl_opt_pass *
make_pass_track_speculation (gcc::context *ctxt)
{
return new pass_track_speculation (ctxt);
return new pass_track_speculation (ctxt, /*is_late=*/false);
}
rtl_opt_pass *
make_pass_late_track_speculation (gcc::context *ctxt)
{
return new pass_track_speculation (ctxt, /*is_late=*/true);
}

View file

@@ -2659,6 +2659,27 @@ aarch64_gen_compare_zero_and_branch (rtx_code code, rtx x,
return gen_rtx_SET (pc_rtx, x);
}
/* Return an rtx that branches to LABEL based on the value of bit BITNUM of X.
If CODE is NE, it branches to LABEL when the bit is set; if CODE is EQ,
it branches to LABEL when the bit is clear. */
static rtx
aarch64_gen_test_and_branch (rtx_code code, rtx x, int bitnum,
rtx_code_label *label)
{
auto mode = GET_MODE (x);
if (aarch64_track_speculation)
{
auto mask = gen_int_mode (HOST_WIDE_INT_1U << bitnum, mode);
emit_insn (gen_aarch64_and3nr_compare0 (mode, x, mask));
rtx cc_reg = gen_rtx_REG (CC_NZVmode, CC_REGNUM);
rtx x = gen_rtx_fmt_ee (code, CC_NZVmode, cc_reg, const0_rtx);
return gen_condjump (x, cc_reg, label);
}
return gen_aarch64_tb (code, mode, mode,
x, gen_int_mode (bitnum, mode), label);
}
/* Consider the operation:
OPERANDS[0] = CODE (OPERANDS[1], OPERANDS[2]) + OPERANDS[3]
@@ -4881,8 +4902,9 @@ aarch64_guard_switch_pstate_sm (rtx old_svcr, aarch64_feature_flags local_mode)
gcc_assert (local_mode != 0);
auto already_ok_cond = (local_mode & AARCH64_FL_SM_ON ? NE : EQ);
auto *label = gen_label_rtx ();
auto *jump = emit_jump_insn (gen_aarch64_tb (already_ok_cond, DImode, DImode,
old_svcr, const0_rtx, label));
auto branch = aarch64_gen_test_and_branch (already_ok_cond, old_svcr, 0,
label);
auto *jump = emit_jump_insn (branch);
JUMP_LABEL (jump) = label;
return label;
}

View file

@@ -439,6 +439,12 @@
(define_enum_attr "arch" "arches" (const_string "any"))
;; Whether a normal INSN in fact contains a call. Sometimes we represent
;; calls to functions that use an ad-hoc ABI as normal insns, both for
;; optimization reasons and to avoid the need to describe the ABI to
;; target-independent code.
(define_attr "is_call" "no,yes" (const_string "no"))
;; [For compatibility with Arm in pipeline models]
;; Attribute that specifies whether or not the instruction touches fp
;; registers.
@@ -5395,7 +5401,7 @@
[(set_attr "type" "alus_imm")]
)
(define_insn "*and<mode>3nr_compare0"
(define_insn "@aarch64_and<mode>3nr_compare0"
[(set (reg:CC_NZV CC_REGNUM)
(compare:CC_NZV
(and:GPI (match_operand:GPI 0 "register_operand")

View file

@@ -0,0 +1,209 @@
// { dg-options "-O -fomit-frame-pointer -fno-optimize-sibling-calls -funwind-tables -mtrack-speculation" }
// { dg-final { check-function-bodies "**" "" } }
void ns_callee ();
void s_callee () [[arm::streaming]];
void sc_callee () [[arm::streaming_compatible]];
void ns_callee_stack (int, int, int, int, int, int, int, int, int);
struct callbacks {
void (*ns_ptr) ();
void (*s_ptr) () [[arm::streaming]];
void (*sc_ptr) () [[arm::streaming_compatible]];
};
/*
** sc_caller_sme:
** cmp sp, #?0
** csetm x15, ne
** stp x29, x30, \[sp, #?-96\]!
** mov x29, sp
** cntd x16
** str x16, \[sp, #?24\]
** stp d8, d9, \[sp, #?32\]
** stp d10, d11, \[sp, #?48\]
** stp d12, d13, \[sp, #?64\]
** stp d14, d15, \[sp, #?80\]
** mrs x16, svcr
** str x16, \[x29, #?16\]
** ldr x16, \[x29, #?16\]
** tst x16, #?1
** beq [^\n]*
** csel x15, x15, xzr, ne
** smstop sm
** b [^\n]*
** csel x15, x15, xzr, eq
** mov x14, sp
** and x14, x14, x15
** mov sp, x14
** bl ns_callee
** cmp sp, #?0
** csetm x15, ne
** ldr x16, \[x29, #?16\]
** tst x16, #?1
** beq [^\n]*
** csel x15, x15, xzr, ne
** smstart sm
** b [^\n]*
** csel x15, x15, xzr, eq
** ldr x16, \[x29, #?16\]
** tst x16, #?1
** bne [^\n]*
** csel x15, x15, xzr, eq
** smstart sm
** b [^\n]*
** csel x15, x15, xzr, ne
** mov x14, sp
** and x14, x14, x15
** mov sp, x14
** bl s_callee
** cmp sp, #?0
** csetm x15, ne
** ldr x16, \[x29, #?16\]
** tst x16, #?1
** bne [^\n]*
** csel x15, x15, xzr, eq
** smstop sm
** b [^\n]*
** csel x15, x15, xzr, ne
** mov x14, sp
** and x14, x14, x15
** mov sp, x14
** bl sc_callee
** cmp sp, #?0
** csetm x15, ne
** ldp d8, d9, \[sp, #?32\]
** ldp d10, d11, \[sp, #?48\]
** ldp d12, d13, \[sp, #?64\]
** ldp d14, d15, \[sp, #?80\]
** ldp x29, x30, \[sp\], #?96
** mov x14, sp
** and x14, x14, x15
** mov sp, x14
** ret
*/
void
sc_caller_sme () [[arm::streaming_compatible]]
{
ns_callee ();
s_callee ();
sc_callee ();
}
#pragma GCC target "+nosme"
/*
** sc_caller:
** cmp sp, #?0
** csetm x15, ne
** stp x29, x30, \[sp, #?-96\]!
** mov x29, sp
** cntd x16
** str x16, \[sp, #?24\]
** stp d8, d9, \[sp, #?32\]
** stp d10, d11, \[sp, #?48\]
** stp d12, d13, \[sp, #?64\]
** stp d14, d15, \[sp, #?80\]
** mov x14, sp
** and x14, x14, x15
** mov sp, x14
** bl __arm_sme_state
** cmp sp, #?0
** csetm x15, ne
** str x0, \[x29, #?16\]
** ...
** bl sc_callee
** cmp sp, #?0
** csetm x15, ne
** ldp d8, d9, \[sp, #?32\]
** ldp d10, d11, \[sp, #?48\]
** ldp d12, d13, \[sp, #?64\]
** ldp d14, d15, \[sp, #?80\]
** ldp x29, x30, \[sp\], #?96
** mov x14, sp
** and x14, x14, x15
** mov sp, x14
** ret
*/
void
sc_caller () [[arm::streaming_compatible]]
{
ns_callee ();
sc_callee ();
}
/*
** sc_caller_x0:
** ...
** mov x10, x0
** mov x14, sp
** and x14, x14, x15
** mov sp, x14
** bl __arm_sme_state
** ...
** str wzr, \[x10\]
** ...
*/
void
sc_caller_x0 (int *ptr) [[arm::streaming_compatible]]
{
*ptr = 0;
ns_callee ();
sc_callee ();
}
/*
** sc_caller_x1:
** ...
** mov x10, x0
** mov x11, x1
** mov x14, sp
** and x14, x14, x15
** mov sp, x14
** bl __arm_sme_state
** ...
** str w11, \[x10\]
** ...
*/
void
sc_caller_x1 (int *ptr, int a) [[arm::streaming_compatible]]
{
*ptr = a;
ns_callee ();
sc_callee ();
}
/*
** sc_caller_stack:
** cmp sp, #?0
** csetm x15, ne
** sub sp, sp, #112
** stp x29, x30, \[sp, #?16\]
** add x29, sp, #?16
** ...
** stp d8, d9, \[sp, #?48\]
** ...
** bl __arm_sme_state
** cmp sp, #?0
** csetm x15, ne
** str x0, \[x29, #?16\]
** ...
** bl ns_callee_stack
** cmp sp, #?0
** csetm x15, ne
** ldr x16, \[x29, #?16\]
** tst x16, #?1
** beq [^\n]*
** csel x15, x15, xzr, ne
** smstart sm
** ...
*/
void
sc_caller_stack () [[arm::streaming_compatible]]
{
ns_callee_stack (0, 0, 0, 0, 0, 0, 0, 0, 0);
}
/* { dg-final { scan-assembler {sc_caller_sme:(?:(?!ret).)*\.cfi_offset 46, -72\n} } } */
/* { dg-final { scan-assembler {sc_caller:(?:(?!ret).)*\.cfi_offset 46, -72\n} } } */