Skip to content

Commit

Permalink
Shader execution tests for f16 built-in log, log2 (gpuweb#2880)
Browse files Browse the repository at this point in the history
This PR add execution tests for f16 built-in log and log2.

Issue: gpuweb#1248
  • Loading branch information
jzm-intel authored Aug 1, 2023
1 parent a8d2a82 commit fa8f784
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 46 deletions.
108 changes: 78 additions & 30 deletions src/unittests/floating_point.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2968,61 +2968,109 @@ g.test('lengthIntervalScalar_f32')
);
});

g.test('logInterval_f32')
.paramsSubcasesOnly<ScalarToIntervalCase>(
// prettier-ignore
[
{ input: -1, expected: kAnyBounds },
{ input: 0, expected: kAnyBounds },
{ input: 1, expected: 0 },
{ input: kValue.f32.positive.e, expected: [kMinusOneULPFunctions['f32'](1), 1] },
{ input: kValue.f32.positive.max, expected: [kMinusOneULPFunctions['f32'](reinterpretU32AsF32(0x42b17218)), reinterpretU32AsF32(0x42b17218)] }, // ~88.72...
]
// prettier-ignore
const kLogIntervalCases = {
f32: [
// kValue.f32.positive.e is 0x402DF854 = 2.7182817459106445,
// log(0x402DF854) = 0.99999996963214000677592342891704 rounded to f32 0x3F7FFFFF or 0x3F800000 = 1.0
{ input: kValue.f32.positive.e, expected: [kMinusOneULPFunctions['f32'](1.0), 1.0] },
// kValue.f32.positive.max is 0x7F7FFFFF = 3.4028234663852886e+38,
// log(0x7F7FFFFF) = 88.72283905206835305421152826479 rounded to f32 0x42B17217 or 0x42B17218.
{ input: kValue.f32.positive.max, expected: [kMinusOneULPFunctions['f32'](reinterpretU32AsF32(0x42b17218)), reinterpretU32AsF32(0x42b17218)] },
] as ScalarToIntervalCase[],
f16: [
// kValue.f16.positive.e is 0x416F = 2.716796875,
// log(0x416F) = 0.99945356688393512460279716546501 rounded to f16 0x3BFE or 0x3BFF.
{ input: kValue.f16.positive.e, expected: [reinterpretU16AsF16(0x3bfe), reinterpretU16AsF16(0x3bff)] },
// kValue.f16.positive.max is 0x7BFF = 65504,
// log(0x7BFF) = 11.089866488461016076210728979771 rounded to f16 0x498B or 0x498C.
{ input: kValue.f16.positive.max, expected: [reinterpretU16AsF16(0x498b), reinterpretU16AsF16(0x498c)] },
] as ScalarToIntervalCase[],
} as const;

g.test('logInterval')
.params(u =>
u
.combine('trait', ['f32', 'f16'] as const)
.beginSubcases()
.expandWithParams<ScalarToIntervalCase>(p => {
// prettier-ignore
return [
{ input: -1, expected: kAnyBounds },
{ input: 0, expected: kAnyBounds },
{ input: 1, expected: 0 },
...kLogIntervalCases[p.trait],
];
})
)
.fn(t => {
const trait = FP[t.params.trait];
const abs_error = t.params.trait === 'f32' ? 2 ** -21 : 2 ** -7;
const error = (n: number): number => {
if (t.params.input >= 0.5 && t.params.input <= 2.0) {
return 2 ** -21;
return abs_error;
}
return 3 * oneULPF32(n);
return 3 * trait.oneULP(n);
};

t.params.expected = applyError(t.params.expected, error);
const expected = FP.f32.toInterval(t.params.expected);
const expected = trait.toInterval(t.params.expected);

const got = FP.f32.logInterval(t.params.input);
const got = trait.logInterval(t.params.input);
t.expect(
objectEquals(expected, got),
`f32.logInterval(${t.params.input}) returned ${got}. Expected ${expected}`
`${t.params.trait}.logInterval(${t.params.input}) returned ${got}. Expected ${expected}`
);
});

g.test('log2Interval_f32')
.paramsSubcasesOnly<ScalarToIntervalCase>(
// prettier-ignore
[
{ input: -1, expected: kAnyBounds },
{ input: 0, expected: kAnyBounds },
{ input: 1, expected: 0 },
{ input: 2, expected: 1 },
{ input: kValue.f32.positive.max, expected: [kMinusOneULPFunctions['f32'](128), 128] },
]
// prettier-ignore
const kLog2IntervalCases = {
f32: [
// kValue.f32.positive.max is 0x7F7FFFFF = 3.4028234663852886e+38,
// log2(0x7F7FFFFF) = 127.99999991400867200665269600978 rounded to f32 0x42FFFFFF or 0x43000000 = 128.0
{ input: kValue.f32.positive.max, expected: [kMinusOneULPFunctions['f32'](128.0), 128.0] },
] as ScalarToIntervalCase[],
f16: [
// kValue.f16.positive.max is 0x7BFF = 65504,
// log2(0x7BFF) = 15.999295387023410627258428389903 rounded to f16 0x4BFF or 0x4C00 = 16.0
{ input: kValue.f16.positive.max, expected: [kMinusOneULPFunctions['f16'](16.0), 16.0] },
] as ScalarToIntervalCase[],
} as const;

g.test('log2Interval')
.params(u =>
u
.combine('trait', ['f32', 'f16'] as const)
.beginSubcases()
.expandWithParams<ScalarToIntervalCase>(p => {
// prettier-ignore
return [
{ input: -1, expected: kAnyBounds },
{ input: 0, expected: kAnyBounds },
{ input: 1, expected: 0 },
{ input: 2, expected: 1 },
{ input: 16, expected: 4 },
...kLog2IntervalCases[p.trait],
];
})
)
.fn(t => {
const trait = FP[t.params.trait];
const abs_error = t.params.trait === 'f32' ? 2 ** -21 : 2 ** -7;
const error = (n: number): number => {
if (t.params.input >= 0.5 && t.params.input <= 2.0) {
return 2 ** -21;
return abs_error;
}
return 3 * oneULPF32(n);
return 3 * trait.oneULP(n);
};

t.params.expected = applyError(t.params.expected, error);
const expected = FP.f32.toInterval(t.params.expected);
const expected = trait.toInterval(t.params.expected);

const got = FP.f32.log2Interval(t.params.input);
const got = trait.log2Interval(t.params.input);
t.expect(
objectEquals(expected, got),
`f32.log2Interval(${t.params.input}) returned ${got}. Expected ${expected}`
`${t.params.trait}.log2Interval(${t.params.input}) returned ${got}. Expected ${expected}`
);
});

Expand Down
30 changes: 24 additions & 6 deletions src/webgpu/shader/execution/expression/call/builtin/log.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ Returns the natural logarithm of e. Component-wise when T is a vector.
import { makeTestGroup } from '../../../../../../common/framework/test_group.js';
import { GPUTest } from '../../../../../gpu_test.js';
import { kValue } from '../../../../../util/constants.js';
import { TypeF32 } from '../../../../../util/conversion.js';
import { TypeF32, TypeF16 } from '../../../../../util/conversion.js';
import { FP } from '../../../../../util/floating_point.js';
import { biasedRange, fullF32Range, linearRange } from '../../../../../util/math.js';
import { biasedRange, fullF32Range, fullF16Range, linearRange } from '../../../../../util/math.js';
import { makeCaseCache } from '../../case_cache.js';
import { allInputSources, run } from '../../expression.js';

Expand All @@ -21,19 +21,31 @@ import { builtin } from './builtin.js';
export const g = makeTestGroup(GPUTest);

// log's accuracy is defined in three regions { [0, 0.5), [0.5, 2.0], (2.0, +∞] }
const inputs = [
const f32_inputs = [
...linearRange(kValue.f32.positive.min, 0.5, 20),
...linearRange(0.5, 2.0, 20),
...biasedRange(2.0, 2 ** 32, 1000),
...fullF32Range(),
];
const f16_inputs = [
...linearRange(kValue.f16.positive.min, 0.5, 20),
...linearRange(0.5, 2.0, 20),
...biasedRange(2.0, 2 ** 32, 1000),
...fullF16Range(),
];

export const d = makeCaseCache('log', {
f32_const: () => {
return FP.f32.generateScalarToIntervalCases(inputs, 'finite', FP.f32.logInterval);
return FP.f32.generateScalarToIntervalCases(f32_inputs, 'finite', FP.f32.logInterval);
},
f32_non_const: () => {
return FP.f32.generateScalarToIntervalCases(inputs, 'unfiltered', FP.f32.logInterval);
return FP.f32.generateScalarToIntervalCases(f32_inputs, 'unfiltered', FP.f32.logInterval);
},
f16_const: () => {
return FP.f16.generateScalarToIntervalCases(f16_inputs, 'finite', FP.f16.logInterval);
},
f16_non_const: () => {
return FP.f16.generateScalarToIntervalCases(f16_inputs, 'unfiltered', FP.f16.logInterval);
},
});

Expand Down Expand Up @@ -68,4 +80,10 @@ g.test('f16')
.params(u =>
u.combine('inputSource', allInputSources).combine('vectorize', [undefined, 2, 3, 4] as const)
)
.unimplemented();
.beforeAllSubcases(t => {
t.selectDeviceOrSkipTestCase('shader-f16');
})
.fn(async t => {
const cases = await d.get(t.params.inputSource === 'const' ? 'f16_const' : 'f16_non_const');
await run(t, builtin('log'), [TypeF16], TypeF16, t.params, cases);
});
30 changes: 24 additions & 6 deletions src/webgpu/shader/execution/expression/call/builtin/log2.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ Returns the base-2 logarithm of e. Component-wise when T is a vector.
import { makeTestGroup } from '../../../../../../common/framework/test_group.js';
import { GPUTest } from '../../../../../gpu_test.js';
import { kValue } from '../../../../../util/constants.js';
import { TypeF32 } from '../../../../../util/conversion.js';
import { TypeF32, TypeF16 } from '../../../../../util/conversion.js';
import { FP } from '../../../../../util/floating_point.js';
import { biasedRange, fullF32Range, linearRange } from '../../../../../util/math.js';
import { biasedRange, fullF32Range, fullF16Range, linearRange } from '../../../../../util/math.js';
import { makeCaseCache } from '../../case_cache.js';
import { allInputSources, run } from '../../expression.js';

Expand All @@ -21,19 +21,31 @@ import { builtin } from './builtin.js';
export const g = makeTestGroup(GPUTest);

// log2's accuracy is defined in three regions { [0, 0.5), [0.5, 2.0], (2.0, +∞] }
const inputs = [
const f32_inputs = [
...linearRange(kValue.f32.positive.min, 0.5, 20),
...linearRange(0.5, 2.0, 20),
...biasedRange(2.0, 2 ** 32, 1000),
...fullF32Range(),
];
const f16_inputs = [
...linearRange(kValue.f16.positive.min, 0.5, 20),
...linearRange(0.5, 2.0, 20),
...biasedRange(2.0, 2 ** 32, 1000),
...fullF16Range(),
];

export const d = makeCaseCache('log2', {
f32_const: () => {
return FP.f32.generateScalarToIntervalCases(inputs, 'finite', FP.f32.log2Interval);
return FP.f32.generateScalarToIntervalCases(f32_inputs, 'finite', FP.f32.log2Interval);
},
f32_non_const: () => {
return FP.f32.generateScalarToIntervalCases(inputs, 'unfiltered', FP.f32.log2Interval);
return FP.f32.generateScalarToIntervalCases(f32_inputs, 'unfiltered', FP.f32.log2Interval);
},
f16_const: () => {
return FP.f16.generateScalarToIntervalCases(f16_inputs, 'finite', FP.f16.log2Interval);
},
f16_non_const: () => {
return FP.f16.generateScalarToIntervalCases(f16_inputs, 'unfiltered', FP.f16.log2Interval);
},
});

Expand Down Expand Up @@ -68,4 +80,10 @@ g.test('f16')
.params(u =>
u.combine('inputSource', allInputSources).combine('vectorize', [undefined, 2, 3, 4] as const)
)
.unimplemented();
.beforeAllSubcases(t => {
t.selectDeviceOrSkipTestCase('shader-f16');
})
.fn(async t => {
const cases = await d.get(t.params.inputSource === 'const' ? 'f16_const' : 'f16_non_const');
await run(t, builtin('log2'), [TypeF16], TypeF16, t.params, cases);
});
12 changes: 8 additions & 4 deletions src/webgpu/util/floating_point.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3437,8 +3437,10 @@ export abstract class FPTraits {
impl: this.limitScalarToIntervalDomain(
this.constants().greaterThanZeroInterval,
(n: number): FPInterval => {
assert(this.kind === 'f32' || this.kind === 'f16');
const abs_error = this.kind === 'f32' ? 2 ** -21 : 2 ** -7;
if (n >= 0.5 && n <= 2.0) {
return this.absoluteErrorInterval(Math.log(n), 2 ** -21);
return this.absoluteErrorInterval(Math.log(n), abs_error);
}
return this.ulpInterval(Math.log(n), 3);
}
Expand All @@ -3456,8 +3458,10 @@ export abstract class FPTraits {
impl: this.limitScalarToIntervalDomain(
this.constants().greaterThanZeroInterval,
(n: number): FPInterval => {
assert(this.kind === 'f32' || this.kind === 'f16');
const abs_error = this.kind === 'f32' ? 2 ** -21 : 2 ** -7;
if (n >= 0.5 && n <= 2.0) {
return this.absoluteErrorInterval(Math.log2(n), 2 ** -21);
return this.absoluteErrorInterval(Math.log2(n), abs_error);
}
return this.ulpInterval(Math.log2(n), 3);
}
Expand Down Expand Up @@ -4970,8 +4974,8 @@ class F16Traits extends FPTraits {
public readonly inverseSqrtInterval = this.inverseSqrtIntervalImpl.bind(this);
public readonly ldexpInterval = this.unimplementedScalarPairToInterval.bind(this);
public readonly lengthInterval = this.unimplementedLength.bind(this);
public readonly logInterval = this.unimplementedScalarToInterval.bind(this);
public readonly log2Interval = this.unimplementedScalarToInterval.bind(this);
public readonly logInterval = this.logIntervalImpl.bind(this);
public readonly log2Interval = this.log2IntervalImpl.bind(this);
public readonly maxInterval = this.maxIntervalImpl.bind(this);
public readonly minInterval = this.minIntervalImpl.bind(this);
public readonly mixImpreciseInterval = this.unimplementedScalarTripleToInterval.bind(this);
Expand Down

0 comments on commit fa8f784

Please sign in to comment.