Skip to content

Commit

Permalink
Do some basic validation of Target Features (#7986) (#7987)
Browse files Browse the repository at this point in the history
* Do some basic validation of Target Features (#7986)

* Update Target.cpp

* Update Target.cpp

* Fixes

* Update Target.cpp

* Improve error messaging.

* format

* Update Target.cpp
  • Loading branch information
steven-johnson authored Dec 8, 2023
1 parent 9c099c2 commit 357e646
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 8 deletions.
5 changes: 1 addition & 4 deletions python_bindings/test/correctness/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,14 @@ def test_target():
32,
[
hl.TargetFeature.JIT,
hl.TargetFeature.SSE41,
hl.TargetFeature.AVX,
hl.TargetFeature.AVX2,
hl.TargetFeature.CUDA,
hl.TargetFeature.OpenCL,
hl.TargetFeature.OpenGLCompute,
hl.TargetFeature.Debug,
],
)
ts = t1.to_string()
assert ts == "arm-32-android-avx-avx2-cuda-debug-jit-opencl-openglcompute-sse41"
assert ts == "arm-32-android-cuda-debug-jit-opencl-openglcompute"
assert hl.Target.validate_target_string(ts)

# Expected failures:
Expand Down
83 changes: 83 additions & 0 deletions src/Target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -785,8 +785,90 @@ void bad_target_string(const std::string &target) {
<< "On this platform, the host target is: " << get_host_target().to_string() << "\n";
}

void do_check_bad(const Target &t, const std::initializer_list<Target::Feature> &v) {
for (Target::Feature f : v) {
user_assert(!t.has_feature(f))
<< "Target feature " << Target::feature_to_name(f)
<< " is incompatible with the Target's architecture. (" << t << ")\n";
}
}

} // namespace

void Target::validate_features() const {
// Note that the features don't have to be exhaustive, but enough to avoid obvious mistakes is good.
if (arch == X86) {
do_check_bad(*this, {
ARMDotProd,
ARMFp16,
ARMv7s,
ARMv81a,
NoNEON,
POWER_ARCH_2_07,
RVV,
SVE,
SVE2,
VSX,
WasmBulkMemory,
WasmMvpOnly,
WasmSimd128,
WasmThreads,
});
} else if (arch == ARM) {
do_check_bad(*this, {
AVX,
AVX2,
AVX512,
AVX512_Cannonlake,
AVX512_KNL,
AVX512_SapphireRapids,
AVX512_Skylake,
AVX512_Zen4,
F16C,
FMA,
FMA4,
POWER_ARCH_2_07,
RVV,
SSE41,
VSX,
WasmBulkMemory,
WasmMvpOnly,
WasmSimd128,
WasmThreads,
});
} else if (arch == WebAssembly) {
do_check_bad(*this, {
ARMDotProd,
ARMFp16,
ARMv7s,
ARMv81a,
AVX,
AVX2,
AVX512,
AVX512_Cannonlake,
AVX512_KNL,
AVX512_SapphireRapids,
AVX512_Skylake,
AVX512_Zen4,
F16C,
FMA,
FMA4,
HVX_128,
HVX_128,
HVX_v62,
HVX_v65,
HVX_v66,
NoNEON,
POWER_ARCH_2_07,
RVV,
SSE41,
SVE,
SVE2,
VSX,
});
}
}

Target::Target(const std::string &target) {
Target host = get_host_target();

Expand All @@ -798,6 +880,7 @@ Target::Target(const std::string &target) {
bad_target_string(target);
}
}
validate_features();
}

Target::Target(const char *s)
Expand Down
6 changes: 6 additions & 0 deletions src/Target.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ struct Target {
for (const auto &f : initial_features) {
set_feature(f);
}
validate_features();
}

Target(OS o, Arch a, int b, const std::vector<Feature> &initial_features = std::vector<Feature>())
Expand Down Expand Up @@ -357,6 +358,11 @@ struct Target {
private:
/** A bitmask that stores the active features. */
std::bitset<FeatureEnd> features;

/** Attempt to validate that all features set are sensible for the base Target.
* This is *not* guaranteed to get all invalid combinations, but is intended
* to catch at least the most common (e.g., setting arm-specific features on x86). */
void validate_features() const;
};

/** Return the target corresponding to the host machine. */
Expand Down
7 changes: 3 additions & 4 deletions test/correctness/target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,10 @@ int main(int argc, char **argv) {

// Full specification round-trip, crazy features
t1 = Target(Target::Android, Target::ARM, 32,
{Target::JIT, Target::SSE41, Target::AVX, Target::AVX2,
Target::CUDA, Target::OpenCL, Target::OpenGLCompute,
Target::Debug});
{Target::JIT, Target::CUDA, Target::OpenCL,
Target::OpenGLCompute, Target::Debug});
ts = t1.to_string();
if (ts != "arm-32-android-avx-avx2-cuda-debug-jit-opencl-openglcompute-sse41") {
if (ts != "arm-32-android-cuda-debug-jit-opencl-openglcompute") {
printf("to_string failure: %s\n", ts.c_str());
return 1;
}
Expand Down

0 comments on commit 357e646

Please sign in to comment.