openmp: Add OpenMP _BitInt support [PR113409]

The following patch adds support for _BitInt iterators of OpenMP canonical
loops (with the preexisting limitation that when not using compile time
static scheduling the iterators in the library are at most unsigned long long
or signed long, so one can't in the runtime/dynamic/guided etc. cases iterate
more than what those types can represent, like is the case of e.g. __int128
iterators too) and the testcase also covers linear/reduction clauses for them.

2024-01-17  Jakub Jelinek  <jakub@redhat.com>

	PR middle-end/113409
	* omp-general.cc (omp_adjust_for_condition): Handle BITINT_TYPE like
	INTEGER_TYPE.
	(omp_extract_for_data): Use build_bitint_type rather than
	build_nonstandard_integer_type if either iter_type or loop->v type
	is BITINT_TYPE.
	* omp-expand.cc (expand_omp_for_generic,
	expand_omp_taskloop_for_outer, expand_omp_taskloop_for_inner): Handle
	BITINT_TYPE like INTEGER_TYPE.

	* testsuite/libgomp.c/bitint-1.c: New test.
This commit is contained in:
Jakub Jelinek 2024-01-17 10:47:31 +01:00
parent 557dbbac8e
commit c8f1045679
3 changed files with 84 additions and 8 deletions

View file

@ -4075,7 +4075,7 @@ expand_omp_for_generic (struct omp_region *region,
/* See if we need to bias by LLONG_MIN. */
if (fd->iter_type == long_long_unsigned_type_node
&& TREE_CODE (type) == INTEGER_TYPE
&& (TREE_CODE (type) == INTEGER_TYPE || TREE_CODE (type) == BITINT_TYPE)
&& !TYPE_UNSIGNED (type)
&& fd->ordered == 0)
{
@ -7191,7 +7191,7 @@ expand_omp_taskloop_for_outer (struct omp_region *region,
/* See if we need to bias by LLONG_MIN. */
if (fd->iter_type == long_long_unsigned_type_node
&& TREE_CODE (type) == INTEGER_TYPE
&& (TREE_CODE (type) == INTEGER_TYPE || TREE_CODE (type) == BITINT_TYPE)
&& !TYPE_UNSIGNED (type))
{
tree n1, n2;
@ -7352,7 +7352,7 @@ expand_omp_taskloop_for_inner (struct omp_region *region,
/* See if we need to bias by LLONG_MIN. */
if (fd->iter_type == long_long_unsigned_type_node
&& TREE_CODE (type) == INTEGER_TYPE
&& (TREE_CODE (type) == INTEGER_TYPE || TREE_CODE (type) == BITINT_TYPE)
&& !TYPE_UNSIGNED (type))
{
tree n1, n2;

View file

@ -115,7 +115,8 @@ omp_adjust_for_condition (location_t loc, enum tree_code *cond_code, tree *n2,
case NE_EXPR:
gcc_assert (TREE_CODE (step) == INTEGER_CST);
if (TREE_CODE (TREE_TYPE (v)) == INTEGER_TYPE)
if (TREE_CODE (TREE_TYPE (v)) == INTEGER_TYPE
|| TREE_CODE (TREE_TYPE (v)) == BITINT_TYPE)
{
if (integer_onep (step))
*cond_code = LT_EXPR;
@ -409,6 +410,7 @@ omp_extract_for_data (gomp_for *for_stmt, struct omp_for_data *fd,
loop->v = gimple_omp_for_index (for_stmt, i);
gcc_assert (SSA_VAR_P (loop->v));
gcc_assert (TREE_CODE (TREE_TYPE (loop->v)) == INTEGER_TYPE
|| TREE_CODE (TREE_TYPE (loop->v)) == BITINT_TYPE
|| TREE_CODE (TREE_TYPE (loop->v)) == POINTER_TYPE);
var = TREE_CODE (loop->v) == SSA_NAME ? SSA_NAME_VAR (loop->v) : loop->v;
loop->n1 = gimple_omp_for_initial (for_stmt, i);
@ -479,9 +481,17 @@ omp_extract_for_data (gomp_for *for_stmt, struct omp_for_data *fd,
else if (i == 0
|| TYPE_PRECISION (iter_type)
< TYPE_PRECISION (TREE_TYPE (loop->v)))
iter_type
= build_nonstandard_integer_type
(TYPE_PRECISION (TREE_TYPE (loop->v)), 1);
{
if (TREE_CODE (iter_type) == BITINT_TYPE
|| TREE_CODE (TREE_TYPE (loop->v)) == BITINT_TYPE)
iter_type
= build_bitint_type (TYPE_PRECISION (TREE_TYPE (loop->v)),
1);
else
iter_type
= build_nonstandard_integer_type
(TYPE_PRECISION (TREE_TYPE (loop->v)), 1);
}
}
else if (iter_type != long_long_unsigned_type_node)
{
@ -747,7 +757,8 @@ omp_extract_for_data (gomp_for *for_stmt, struct omp_for_data *fd,
if (t && integer_zerop (t))
count = build_zero_cst (long_long_unsigned_type_node);
else if ((i == 0 || count != NULL_TREE)
&& TREE_CODE (TREE_TYPE (loop->v)) == INTEGER_TYPE
&& (TREE_CODE (TREE_TYPE (loop->v)) == INTEGER_TYPE
|| TREE_CODE (TREE_TYPE (loop->v)) == BITINT_TYPE)
&& TREE_CONSTANT (loop->n1)
&& TREE_CONSTANT (loop->n2)
&& TREE_CODE (loop->step) == INTEGER_CST)

View file

@ -0,0 +1,65 @@
/* PR middle-end/113409 */
/* { dg-do run { target bitint } } */
extern void abort (void);
#if __BITINT_MAXWIDTH__ >= 1023
typedef _BitInt(931) B931;
typedef _BitInt(1023) B1023;
#else
typedef _BitInt(31) B931;
typedef _BitInt(63) B1023;
#endif
__attribute__((noipa)) B931
bar (B931 x)
{
return x;
}
B931
foo (B931 x)
{
B931 r = 0;
B1023 l = 56wb;
#pragma omp parallel for reduction(+: r) linear(l : 3wb)
for (B931 i = 0; i < x; ++i)
{
r += bar (i);
l += 3wb;
}
if (l != (B1023) 56wb + x * 3wb)
abort ();
return r;
}
B931
baz (B931 a, B931 b, B931 c, B931 d, B931 e, B931 f)
{
B931 r = 0;
#pragma omp parallel for collapse (2wb) reduction(+: r)
for (B931 i = a; i < b; i += c)
for (B931 j = d; j > e; j += f)
{
r += (j - d) / f;
__builtin_printf ("%d\n", (int) r);
}
return r;
}
int
main ()
{
if (foo (16wb) != (B931) 15wb * 16wb / 2
|| foo (256wb) != (B931) 255wb * 256wb / 2)
abort ();
#if __BITINT_MAXWIDTH__ >= 1023
if (baz (5019676379303764570412381742937286053482001129028025397398691108125646744814606405323608429353439158482254231750681261083217232780938592007150824765654203477280876662295642053702075485153212701225737143207062700602509893062044376997132415613866154761073993220684129908568716699977wb,
5019676379303764570412381742937286053482001129028025397398691108125646744814606405323608429353439158482254231750681261083217232780938592007150824765654203477280876662295642053702075485153212701225737143207062700602509893062044376997132415613866154761074023903103954348393149593648wb,
398472984732984732894723wb,
5145599438319070078334010989312672300490893251953772234670751860213136881221517063143096309285807356778798661066289865489661268190588670564647904159660341525674865064477335008915374460378741763629714814990575971883514175167056160470289039998140910732382754821232566561860399556131wb,
5145599438319070078334010989312672300490893251953772234670751860213136881221517063143096309285807356778798661066289865489661268190588670564647904159660341525674865064477335008915374460378741763629714814990575971883514175167056160470289039995725068563774878643356388697468035164019wb,
-89475635874365784365784365347865347856wb) != (B931) 26wb * 27wb / 2wb * 77wb)
abort ();
#endif
}